From 54b622322757d7d8df949b377132091ac05ca05c Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Sun, 9 Feb 2020 13:07:30 +0400 Subject: [PATCH 01/10] memoryview --- .../main/java/org/datavec/python/Python.java | 11 +++-- .../org/datavec/python/PythonExecutioner.java | 39 +++++++++++------ .../java/org/datavec/python/PythonObject.java | 43 +++++++++++++++++-- .../datavec/python/TestPythonExecutioner.java | 23 +++++++++- 4 files changed, 94 insertions(+), 22 deletions(-) diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java index 80d6643e8032..9dabbef2d330 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java @@ -144,6 +144,14 @@ public static PythonObject bytearrayType() { return attr("bytearray"); } + public static PythonObject memoryview(PythonObject pythonObject) { + return attr("memoryview").call(pythonObject); + } + + public static PythonObject memoryviewType() { + return attr("memoryview"); + } + public static PythonObject bytes(PythonObject pythonObject) { return attr("bytes").call(pythonObject); } @@ -250,9 +258,6 @@ public static void deleteNonMainContexts(){ public static void exec(String code)throws PythonException{ PythonExecutioner.exec(code); } - public static void exec(String code, PythonVariables inputs) throws PythonException{ - PythonExecutioner.exec(code, inputs); - } public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException{ PythonExecutioner.exec(code, inputs, outputs); } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index a06e60e9845e..3a345a689e5f 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -20,23 +20,16 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; -import org.bytedeco.cpython.PyThreadState; -import org.bytedeco.javacpp.BytePointer; import org.bytedeco.numpy.global.numpy; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; import java.io.IOException; import java.io.InputStream; -import java.nio.ByteBuffer; import java.nio.charset.Charset; -import java.util.Arrays; -import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import static org.bytedeco.cpython.global.python.*; -import static org.bytedeco.cpython.global.python.PyThreadState_Get; import static org.datavec.python.Python.*; /** @@ -105,6 +98,7 @@ public class PythonExecutioner { init(); } + private static synchronized void init() { if (init.get()) { return; @@ -204,6 +198,9 @@ public static T getVariable(String varName, PythonType varType) throws Py } public static void getVariables(PythonVariables pyVars) throws PythonException { + if (pyVars == null){ + return; + } for (String varName : pyVars.getVariables()) { pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName))); } @@ -240,12 +237,6 @@ public static void exec(String code) throws PythonException { throwIfExecutionFailed(); } - public static void exec(String code, PythonVariables outputVariables)throws PythonException { - simpleExec(getWrappedCode(code)); - throwIfExecutionFailed(); - getVariables(outputVariables); - } - public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException { setVariables(inputVariables); simpleExec(getWrappedCode(code)); @@ -354,7 +345,6 @@ public static synchronized void initPythonPath() { log.info("Setting python path " + path); StringBuffer sb = new StringBuffer(); File[] packages = numpy.cachePackages(); - JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase()); switch (pathAppendValue) { case BEFORE: @@ -395,4 +385,25 @@ public static synchronized void initPythonPath() { throw new IllegalStateException("Unable to reset python path. Already initialized."); } } + +// +// public static void installArrow(){ +// try{ +// exec("import pyarrow"); +// }catch (PythonException pe){ +// try{ +// String python = Loader.load(org.bytedeco.cpython.python.class); +// ProcessBuilder pb = new ProcessBuilder(python, "-m", "pip", "install", "pyarrow"); +// pb.inheritIO().start().waitFor(); +// }catch (Exception e){ +// log.warn("Unable to install pyarrow" + e); +// } +// try{ +// exec("import pyarrow"); +// } +// catch(Exception e){ +// log.warn("Installation failed for pyarrow" + e); +// } +// } +// } } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java index f1d54168bef9..c0079919c47b 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java @@ -69,7 +69,11 @@ public PythonObject(INDArray npArray) { } public PythonObject(BytePointer bp){ - nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity()); + + long address = bp.address(); + long size = bp.capacity(); + NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build(); + nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject; } public PythonObject(NumpyArray npArray) { @@ -343,13 +347,28 @@ public NumpyArray toNumpy() { dtype = DataType.DOUBLE; } else if (dtypeName.equals("float32")) { dtype = DataType.FLOAT; - } else if (dtypeName.equals("int16")) { + } else if (dtypeName.equals("int8")){ + dtype = DataType.INT8; + }else if (dtypeName.equals("int16")) { dtype = DataType.SHORT; } else if (dtypeName.equals("int32")) { dtype = DataType.INT; } else if (dtypeName.equals("int64")) { dtype = DataType.LONG; - } else { + } + else if (dtypeName.equals("uint8")){ + dtype = DataType.UINT8; + } + else if (dtypeName.equals("uint16")){ + dtype = DataType.UINT16; + } + else if (dtypeName.equals("uint32")){ + dtype = DataType.UINT32; + } + else if (dtypeName.equals("uint64")){ + dtype = DataType.UINT64; + } + else { throw new RuntimeException("Unsupported array type " + dtypeName + "."); } return new NumpyArray(address, jshape, jstrides, dtype); @@ -518,6 +537,22 @@ public BytePointer toBytePointer() throws PythonException{ else if (Python.isinstance(this, Python.bytearrayType())){ return PyByteArray_AsString(nativePythonObject); } + else if (Python.isinstance(this, Python.memoryviewType())){ + +// PyObject np = PyImport_ImportModule("numpy"); +// PyObject array = PyObject_GetAttrString(np, "asarray"); +// PyObject npArr = PyObject_CallObject(array, nativePythonObject); // Doesn't work + // Invoke interpreter: + String tempContext = "temp" + UUID.randomUUID().toString().replace('-', '_'); + String originalContext = Python.getCurrentContext(); + Python.setContext(tempContext); + PythonExecutioner.setVariable("memview", this); + PythonExecutioner.exec("import numpy as np\narr = np.array(memview)"); + BytePointer ret = new BytePointer(PythonExecutioner.getVariable("arr").toNumpy().getNd4jArray().data().pointer()); + Python.setContext(originalContext); + Python.deleteContext(tempContext); + return ret; + } else{ PyObject ctypes = PyImport_ImportModule("ctypes"); PyObject cArrType = PyObject_GetAttrString(ctypes, "Array"); @@ -542,7 +577,7 @@ else if (Python.isinstance(this, Python.bytearrayType())){ return new BytePointer(ptr); } else{ - throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString()); + throw new PythonException("Expected bytes, bytearray, memoryview or ctypesArray. Received " + Python.type(this).toString()); } } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java index bb436e808fbd..b8916476cab1 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -237,12 +238,13 @@ public void testByteBufferInput() throws Exception{ PythonVariables pyOutputs= new PythonVariables(); pyOutputs.addStr("out"); - String code = "out = buff.decode()"; + String code = "out = bytes(buff).decode()"; Python.exec(code, pyInputs, pyOutputs); Assert.assertEquals("abc", pyOutputs.getStrValue("out")); } + @Test public void testByteBufferOutputNoCopy() throws Exception{ INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); @@ -262,6 +264,24 @@ public void testByteBufferOutputNoCopy() throws Exception{ Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString()); } + @Test + public void testByteBufferInplace() throws Exception{ + INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); + buff.putScalar(0, 97); // a + buff.putScalar(1, 98); // b + buff.putScalar(2, 99); // c + PythonVariables pyInputs = new PythonVariables(); + pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); + String code = "buff[0]+=2\nbuff[2]-=2"; + Python.exec(code, pyInputs, null); + Assert.assertEquals("cba", pyInputs.getBytesValue("buff").getString()); + INDArray expected = buff.dup(); + expected.putScalar(0, 99); + expected.putScalar(2, 97); + Assert.assertEquals(buff, expected); + + } + @Test public void testByteBufferOutputWithCopy() throws Exception{ INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); @@ -302,4 +322,5 @@ public void testBadCode() throws Exception{ Python.setMainContext(); } + } From f03f37ff1e3d3eca9a2988510c088dd1a652e4f8 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Sun, 9 Feb 2020 13:10:39 +0400 Subject: [PATCH 02/10] cleanup --- .../org/datavec/python/PythonExecutioner.java | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index 3a345a689e5f..e2d2e57477da 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -385,25 +385,5 @@ public static synchronized void initPythonPath() { throw new IllegalStateException("Unable to reset python path. Already initialized."); } } - -// -// public static void installArrow(){ -// try{ -// exec("import pyarrow"); -// }catch (PythonException pe){ -// try{ -// String python = Loader.load(org.bytedeco.cpython.python.class); -// ProcessBuilder pb = new ProcessBuilder(python, "-m", "pip", "install", "pyarrow"); -// pb.inheritIO().start().waitFor(); -// }catch (Exception e){ -// log.warn("Unable to install pyarrow" + e); -// } -// try{ -// exec("import pyarrow"); -// } -// catch(Exception e){ -// log.warn("Installation failed for pyarrow" + e); -// } -// } -// } + } From 378c4875b012afcdba7a2e59b32abd11fae5f321 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Tue, 11 Feb 2020 22:54:50 +0400 Subject: [PATCH 03/10] array, field, schema working --- datavec/datavec-python/pom.xml | 16 +- .../main/java/org/datavec/python/Python.java | 16 ++ .../org/datavec/python/PythonArrowUtils.java | 193 ++++++++++++++++++ .../org/datavec/python/PythonExecutioner.java | 14 ++ .../datavec/python/TestPythonArrowUtils.java | 101 +++++++++ 5 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 55cf6c5dad1e..315aa94d0ed7 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -42,7 +42,16 @@ numpy-platform ${numpy.javacpp.version} - + + org.nd4j + nd4j-arrow + ${nd4j.version} + + + org.bytedeco + arrow-platform + ${arrow.javacpp.version} + com.google.code.findbugs jsr305 @@ -53,6 +62,11 @@ datavec-api ${project.version} + + org.datavec + datavec-arrow + ${project.version} + ch.qos.logback logback-classic diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java index 9dabbef2d330..d0d70b3b2438 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java @@ -112,6 +112,14 @@ public static PythonObject listType() { return attr("list"); } + public static PythonObject list(PythonObject[] pythonObjects){ + PyObject list = PyList_New(pythonObjects.length); + for(int i = 0;i < pythonObjects.length; i++){ + PyList_SetItem(list, i, pythonObjects[i].getNativePythonObject()); + } + return new PythonObject(list); + } + public static PythonObject dict(PythonObject pythonObject) { return attr("dict").call(pythonObject); } @@ -172,6 +180,14 @@ public static PythonObject tuple() { return attr("tuple").call(); } + public static PythonObject tuple(PythonObject[] pythonObjects){ + PyObject tuple = PyTuple_New(pythonObjects.length); + for(int i = 0;i < pythonObjects.length; i++){ + PyTuple_SetItem(tuple, i, pythonObjects[i].getNativePythonObject()); + } + return new PythonObject(tuple); + } + public static PythonObject Exception(PythonObject pythonObject) { return attr("Exception").call(pythonObject); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java new file mode 100644 index 000000000000..a610dd297bba --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java @@ -0,0 +1,193 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import org.bytedeco.arrow.DataType; +import org.bytedeco.arrow.Field; +import org.bytedeco.arrow.FieldVector; +import org.bytedeco.arrow.Schema; +import org.nd4j.linalg.api.ndarray.INDArray; + +import static org.bytedeco.arrow.global.arrow.*; + +public class PythonArrowUtils { + + static { + try{ + importPyArraow().del(); // ensures that we are loading pyarrow's binary, not javacpp's + }catch (Exception e){ + throw new RuntimeException(e); + } + } + public static PythonObject importPyArraow() throws PythonException{ + return Python.importModule("pyarrow"); + } + + public static PythonObject getPyArrowArrayFromINDArray(INDArray arr) throws PythonException{ + PythonObject pyarrow = importPyArraow(); + PythonObject npArr = new PythonObject(arr); + PythonObject arrayF = pyarrow.attr("array"); + PythonObject ret = arrayF.call(npArr); + pyarrow.del(); + npArr.del(); + arrayF.del(); + return ret; + } + + public static INDArray getINDArrayFromPyArrowArray(PythonObject arr) throws PythonException{ + PythonObject pyArrow = Python.importModule("pyarrow"); + PythonObject arrayType = pyArrow.attr("Array"); + if (!Python.isinstance(arr, arrayType)){ + pyArrow.del(); + arrayType.del(); + throw new PythonException("Expected pyarrow.Array, received " + Python.type(arr)); + } + PythonObject toNumpyF = arr.attr("to_numpy"); + PythonObject npArr = toNumpyF.call(); + pyArrow.del(); + arrayType.del(); + toNumpyF.del(); + INDArray ret = npArr.toNumpy().getNd4jArray(); + npArr.del(); + return ret; + } +// private static void installPyArrow() throws Exception{ +// try{ +// importPyArraow().del(); +// }catch (PythonException pe){ +// String python = Loader.load(org.bytedeco.cpython.python.class); +// ProcessBuilder pb = new ProcessBuilder(python, "-m", "pip", "install", "pyarrow==0.15.1"); +// pb.inheritIO().start().waitFor(); +// importPyArraow().del(); +// } +// } + + + public static PythonObject getPyArrowField(Field field) throws PythonException{ + String name = field.name(); + org.bytedeco.arrow.DataType type = field.type(); + String typeName = type.name(); + String pyTypeFName; + switch (typeName){ + case "list": + throw new PythonException("Unsupported field type: list"); + case "bool": + pyTypeFName = "bool_"; + break; + case "halffloat": + pyTypeFName = "float16"; + break; + case "float": + pyTypeFName = "float32"; + break; + case "double": + pyTypeFName = "float64"; + break; + default: + pyTypeFName = typeName; + } + PythonObject pyarrow = importPyArraow(); + PythonObject fieldF = pyarrow.attr("field"); + PythonObject pyArrowTypeF = pyarrow.attr(pyTypeFName); + PythonObject pyArrowType = pyArrowTypeF.call(); + PythonObject pyArrowField = fieldF.call(name, pyArrowType); + pyArrowType.del(); + pyArrowTypeF.del(); + fieldF.del(); + pyarrow.del(); + return pyArrowField; + } + + public static Field getFieldFromPythonObject(PythonObject pyArrowField) throws PythonException{ + PythonObject pyarrow = importPyArraow(); + PythonObject fieldType = pyarrow.attr("Field"); + if(!Python.isinstance(pyArrowField, fieldType)){ + pyarrow.del(); + fieldType.del(); + throw new PythonException("Expected pyarrow.Field, received " + Python.type(pyArrowField)); + } + PythonObject pyName = pyArrowField.attr("name"); + String name = pyName.toString(); + PythonObject pyTypeName = pyArrowField.attr("type"); + String typeName = pyTypeName.toString(); + DataType dt; + switch (typeName){ + case "bool": + dt = _boolean(); + break; + case "halffloat": + dt = float16(); + break; + case "float": + dt = float32(); + break; + case "double": + dt = float64(); + break; + default: + try{ + dt = (DataType)org.bytedeco.arrow.global.arrow.class.getMethod(typeName).invoke(null); + } + catch (Exception e){ + throw new PythonException("Unsupported type: " + typeName, e); + } + } + Field ret = new Field(name, dt); + pyarrow.del(); + fieldType.del(); + pyName.del(); + pyTypeName.del(); + return ret; + } + + public static PythonObject getPyArrowSchema(Schema schema) throws PythonException{ + Field[] fields = schema.fields().get(); + PythonObject[] pyFields = new PythonObject[fields.length]; + for (int i = 0; i < fields.length; i++){ + pyFields[i] = getPyArrowField(fields[i]); + } + PythonObject pyarrow = importPyArraow(); + PythonObject schemaF = pyarrow.attr("schema"); + PythonObject pySchema = schemaF.call(Python.list(pyFields)); + pyarrow.del(); + schemaF.del(); + return pySchema; + } + + public static Schema getSchemaFromPythonObject(PythonObject pyArrowSchema) throws PythonException{ + PythonObject pyarrow = importPyArraow(); + PythonObject schemaType = pyarrow.attr("Schema"); + if(!Python.isinstance(pyArrowSchema, schemaType)){ + pyarrow.del(); + schemaType.del(); + throw new PythonException("Expected pyarrow.Field, received " + Python.type(pyArrowSchema)); + } + PythonObject pySize = Python.len(pyArrowSchema); + int size = pySize.toInt(); + Field[] fields = new Field[size]; + for(int i = 0; i < size; i++){ + PythonObject pyField = pyArrowSchema.get(i); + fields[i] = getFieldFromPythonObject(pyField); + } + pySize.del(); + schemaType.del(); + pyarrow.del(); + return new Schema(new FieldVector(fields)); + } + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index e2d2e57477da..6359d444a9a3 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -20,6 +20,8 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; +import org.bytedeco.cpython.global.python; +import org.bytedeco.javacpp.Loader; import org.bytedeco.numpy.global.numpy; import org.nd4j.linalg.io.ClassPathResource; @@ -340,6 +342,18 @@ public static synchronized void initPythonPath() { if (path == null) { log.info("Setting python default path"); File[] packages = numpy.cachePackages(); + + //// TODO: fix in javacpp + File sitePackagesWindows = new File(python.cachePackage(), "site-packages"); + File[] packages2 = new File[packages.length + 1]; + for (int i = 0;i < packages.length; i++){ + System.out.println(packages[i].getAbsolutePath()); + packages2[i] = packages[i]; + } + packages2[packages.length] = sitePackagesWindows; + System.out.println(sitePackagesWindows.getAbsolutePath()); + packages = packages2; + ////////// Py_SetPath(packages); } else { log.info("Setting python path " + path); diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java new file mode 100644 index 000000000000..5fb5caf64bb8 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java @@ -0,0 +1,101 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import org.bytedeco.arrow.Field; +import org.bytedeco.arrow.FieldVector; +import org.bytedeco.arrow.Schema; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.bytedeco.arrow.global.arrow.*; + +public class TestPythonArrowUtils { + + @Test + public void testPyArrowImport() throws Exception{ + PythonArrowUtils.importPyArraow(); + } + + @Test + public void testINDArrayConversion() throws PythonException{ + INDArray array = Nd4j.rand(10); + PythonObject pyArrowArray = PythonArrowUtils.getPyArrowArrayFromINDArray(array); + INDArray array2 = PythonArrowUtils.getINDArrayFromPyArrowArray(pyArrowArray); + Assert.assertEquals(array, array2); + + // test no copy + Assert.assertEquals(array.data().address(), array2.data().address()); + array.putScalar(0, Nd4j.rand(1).getDouble(0)); + Assert.assertEquals(array, array2); + } + + @Test + public void testFieldConversion() throws PythonException{ + Field[] fields = new Field[]{ + new Field("a", int8()), + new Field("b", int16()), + new Field("c", int32()), + new Field("d", int64()), + new Field("e", uint8()), + new Field("f", uint16()), + new Field("g", uint32()), + new Field("h", uint64()), + new Field("i", _boolean()), + new Field("j", float16()), + new Field("k", float32()), + new Field("l", float64()), + }; + for (Field field: fields){ + PythonObject pyArrowField = PythonArrowUtils.getPyArrowField(field); + Field field2 = PythonArrowUtils.getFieldFromPythonObject(pyArrowField); + Assert.assertEquals(field.name(), field2.name()); + Assert.assertEquals(field.type(), field.type()); + } + } + + @Test + public void testSchemaConversion() throws PythonException{ + Field[] fields = new Field[]{ + new Field("a", int8()), + new Field("b", int16()), + new Field("c", int32()), + new Field("d", int64()), + new Field("e", uint8()), + new Field("f", uint16()), + new Field("g", uint32()), + new Field("h", uint64()), + new Field("i", _boolean()), + new Field("j", float16()), + new Field("k", float32()), + new Field("l", float64()), + }; + Schema schema = new Schema(new FieldVector(fields)); + PythonObject pySchema = PythonArrowUtils.getPyArrowSchema(schema); + Schema schema2 = PythonArrowUtils.getSchemaFromPythonObject(pySchema); + Field[] fields2 = schema2.fields().get(); + Assert.assertEquals(fields.length, fields2.length); + for(int i = 0;i < fields.length; i++){ + Assert.assertEquals(fields[i].name(), fields2[i].name()); + Assert.assertEquals(fields[i].type(), fields2[i].type()); + } + } + +} From ff6988f3e6f152208b6c64e4ba4f726fb39ff2e9 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Tue, 11 Feb 2020 23:00:58 +0400 Subject: [PATCH 04/10] headers --- .../src/main/java/org/datavec/python/PythonArrowUtils.java | 2 +- .../src/test/java/org/datavec/python/TestPythonArrowUtils.java | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java index a610dd297bba..87412930ec6d 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java index 5fb5caf64bb8..c51631483f6a 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,7 +14,6 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - package org.datavec.python; import org.bytedeco.arrow.Field; From 9300ed2e8ab69732f5ea1f99138abd97674134db Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Wed, 12 Feb 2020 07:39:08 +0400 Subject: [PATCH 05/10] table --- .../org/datavec/python/PythonArrowUtils.java | 25 +++++++++++++++---- .../datavec/python/TestPythonArrowUtils.java | 11 ++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java index 87412930ec6d..7388a98fd6c3 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java @@ -17,10 +17,8 @@ package org.datavec.python; -import org.bytedeco.arrow.DataType; -import org.bytedeco.arrow.Field; -import org.bytedeco.arrow.FieldVector; -import org.bytedeco.arrow.Schema; +import org.bytedeco.arrow.*; +import org.datavec.arrow.table.DataVecTable; import org.nd4j.linalg.api.ndarray.INDArray; import static org.bytedeco.arrow.global.arrow.*; @@ -29,7 +27,7 @@ public class PythonArrowUtils { static { try{ - importPyArraow().del(); // ensures that we are loading pyarrow's binary, not javacpp's + new Field("x", int32()); }catch (Exception e){ throw new RuntimeException(e); } @@ -190,4 +188,21 @@ public static Schema getSchemaFromPythonObject(PythonObject pyArrowSchema) throw return new Schema(new FieldVector(fields)); } + public static PythonObject getPyArrowTable(DataVecTable table) throws PythonException{ + PythonObject d = Python.dict(); + for(int i = 0; i < table.numColumns(); i++){ + PythonObject colName = new PythonObject(table.columnNameAt(i)); + PythonObject colArr = new PythonObject(table.column(i).toNdArray()); + d.set(colName, colArr); + //colName.del(); + //colArr.del(); + } + PythonObject pyarrow = importPyArraow(); + PythonObject tableF = pyarrow.attr("table"); + PythonObject pyTable = tableF.call(d); + pyarrow.del(); + tableF.del(); + return pyTable; + } + } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java index c51631483f6a..95c16ac97c7c 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java @@ -19,11 +19,14 @@ import org.bytedeco.arrow.Field; import org.bytedeco.arrow.FieldVector; import org.bytedeco.arrow.Schema; +import org.datavec.arrow.table.column.impl.*; import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import java.util.Arrays; + import static org.bytedeco.arrow.global.arrow.*; public class TestPythonArrowUtils { @@ -61,6 +64,7 @@ public void testFieldConversion() throws PythonException{ new Field("j", float16()), new Field("k", float32()), new Field("l", float64()), + new Field("m", binary()), }; for (Field field: fields){ PythonObject pyArrowField = PythonArrowUtils.getPyArrowField(field); @@ -85,6 +89,7 @@ public void testSchemaConversion() throws PythonException{ new Field("j", float16()), new Field("k", float32()), new Field("l", float64()), + new Field("m", binary()), }; Schema schema = new Schema(new FieldVector(fields)); PythonObject pySchema = PythonArrowUtils.getPyArrowSchema(schema); @@ -97,4 +102,10 @@ public void testSchemaConversion() throws PythonException{ } } + @Test + public void testTableConversion() throws PythonException{ + new DoubleColumn("double", new Double[]{1.0}); + + } + } From 51fc4c535e7f6be1ce1763e243c3518e24578f44 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Wed, 12 Feb 2020 14:46:13 +0400 Subject: [PATCH 06/10] test --- .../datavec/python/TestPythonArrowUtils.java | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java index 95c16ac97c7c..170c6553507d 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java @@ -19,15 +19,22 @@ import org.bytedeco.arrow.Field; import org.bytedeco.arrow.FieldVector; import org.bytedeco.arrow.Schema; +import org.datavec.api.transform.ColumnType; +import org.datavec.arrow.table.DataVecTable; +import org.datavec.arrow.table.column.DataVecColumn; import org.datavec.arrow.table.column.impl.*; +import org.datavec.arrow.table.row.Row; import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; +import java.util.List; import static org.bytedeco.arrow.global.arrow.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; public class TestPythonArrowUtils { @@ -108,4 +115,59 @@ public void testTableConversion() throws PythonException{ } + @Test + public void testTable() { + int count = 0; + ColumnType[] columnTypes = new ColumnType[] { + ColumnType.Integer, + ColumnType.Double, + ColumnType.Float, + ColumnType.Boolean, + ColumnType.String + }; + + DataVecColumn[] dataVecColumns = new DataVecColumn[columnTypes.length]; + DataVecColumn[] dataVecColumnsList = new DataVecColumn[columnTypes.length]; + + for(ColumnType columnType : columnTypes) { + switch(columnType) { + case Double: + dataVecColumns[count] = new DoubleColumn(columnType.name().toLowerCase(),new Double[]{1.0}); + dataVecColumnsList[count] = new DoubleColumn(columnType.name().toLowerCase(), Arrays.asList(1.0)); + break; + case Float: + dataVecColumns[count] = new FloatColumn(columnType.name().toLowerCase(),new Float[]{1.0f}); + dataVecColumnsList[count] = new FloatColumn(columnType.name().toLowerCase(),Arrays.asList(1.0f)); + break; + case Boolean: + dataVecColumns[count] = new BooleanColumn(columnType.name().toLowerCase(),new Boolean[]{true}); + dataVecColumnsList[count] = new BooleanColumn(columnType.name().toLowerCase(),Arrays.asList(true)); + break; + case String: + dataVecColumns[count] = new StringColumn(columnType.name().toLowerCase(),new String[]{"1.0"}); + dataVecColumnsList[count] = new StringColumn(columnType.name().toLowerCase(),Arrays.asList("1.0")); + break; + case Long: + dataVecColumns[count] = new LongColumn(columnType.name().toLowerCase(),new Long[]{1L}); + dataVecColumnsList[count] = new LongColumn(columnType.name().toLowerCase(),Arrays.asList(1L)); + break; + case Integer: + dataVecColumns[count] = new IntColumn(columnType.name().toUpperCase(),new Integer[]{1}); + dataVecColumnsList[count] = new IntColumn(columnType.name().toUpperCase(),Arrays.asList(1)); + break; + + } + + assertEquals(1,dataVecColumns[count].rows()); + assertEquals("Column type of " + columnType + " has wrong number of rows",1,dataVecColumns[count].rows()); + count++; + } + + DataVecTable dataVecTable1 = DataVecTable.create(dataVecColumns); + assertEquals(columnTypes.length,dataVecTable1.numColumns()); + DataVecTable dataVecTableList = DataVecTable.create(dataVecColumnsList); + + + } + } From 56e3f921ead3356b5d13a55566e38d60bcf7fb8c Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Mon, 17 Feb 2020 01:38:27 +0400 Subject: [PATCH 07/10] updates --- .../main/java/org/datavec/python/Python.java | 8 +- .../org/datavec/python/PythonArrowUtils.java | 128 +++++++++++++++--- .../org/datavec/python/PythonProcess.java | 103 ++++++++++++++ .../datavec/python/TestPythonArrowUtils.java | 104 +++++++------- .../org/datavec/python/TestPythonProcess.java | 24 ++++ 5 files changed, 296 insertions(+), 71 deletions(-) create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java index d0d70b3b2438..93b1c8622a21 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java @@ -125,7 +125,7 @@ public static PythonObject dict(PythonObject pythonObject) { } public static PythonObject dict() { - return attr("dict").call(); + return new PythonObject(PyDict_New()); } public static PythonObject dictType() { @@ -282,5 +282,11 @@ public static PythonGIL lock(){ return PythonGIL.lock(); } + public static void setVariable(String name, PythonObject value) throws PythonException{ + PythonExecutioner.setVariable(name, value); + } + public static PythonObject getVariable(String name){ + return PythonExecutioner.getVariable(name); + } } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java index 7388a98fd6c3..83df7fcf4f77 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java @@ -17,23 +17,56 @@ package org.datavec.python; +import lombok.extern.slf4j.Slf4j; import org.bytedeco.arrow.*; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Loader; import org.datavec.arrow.table.DataVecTable; +import org.nd4j.arrow.ByteDecoArrowSerde; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.UUID; import static org.bytedeco.arrow.global.arrow.*; +@Slf4j public class PythonArrowUtils { static { - try{ - new Field("x", int32()); - }catch (Exception e){ - throw new RuntimeException(e); - } + init(); + } + private static String PYARROW = "pyarrow"; + private static String PANDAS = "pandas"; + private static String REQUIRED_PYARROW_VERSION = "0.15.1"; // TODO get version from pom.xml? + private static boolean init = false; + + public static void init(){ + // TODO: Find out why this works + INDArray dummyArr = Nd4j.rand(5); + new DoubleArray(new ArrowBuffer(new BytePointer(dummyArr.data().pointer()), dummyArr.data().length())); } + public static PythonObject importPyArraow() throws PythonException{ - return Python.importModule("pyarrow"); + try{ + if (!PythonProcess.isPackageInstalled(PYARROW)){ + log.info("PyArrow is not installed. Attempting to pip install pyarrow " + REQUIRED_PYARROW_VERSION); + PythonProcess.pipInstall(PYARROW, REQUIRED_PYARROW_VERSION); + PythonProcess.pipInstall(PANDAS); + } + else{ + String pkgVersion = PythonProcess.getPackageVersion(PYARROW); + if (!pkgVersion.equals(REQUIRED_PYARROW_VERSION)) { + log.info("Required pyarrow version " + REQUIRED_PYARROW_VERSION + " but current version is " + pkgVersion + ". Attempting reinstall..."); + PythonProcess.pipInstall(PYARROW, REQUIRED_PYARROW_VERSION); + PythonProcess.pipInstall(PANDAS); + } + } + } catch(Exception e){ + throw new PythonException("Error verifying/installing pyarrow package.", e); + } + + return Python.importModule(PYARROW); } public static PythonObject getPyArrowArrayFromINDArray(INDArray arr) throws PythonException{ @@ -48,7 +81,7 @@ public static PythonObject getPyArrowArrayFromINDArray(INDArray arr) throws Pyth } public static INDArray getINDArrayFromPyArrowArray(PythonObject arr) throws PythonException{ - PythonObject pyArrow = Python.importModule("pyarrow"); + PythonObject pyArrow = Python.importModule(PYARROW); PythonObject arrayType = pyArrow.attr("Array"); if (!Python.isinstance(arr, arrayType)){ pyArrow.del(); @@ -64,17 +97,6 @@ public static INDArray getINDArrayFromPyArrowArray(PythonObject arr) throws Pyth npArr.del(); return ret; } -// private static void installPyArrow() throws Exception{ -// try{ -// importPyArraow().del(); -// }catch (PythonException pe){ -// String python = Loader.load(org.bytedeco.cpython.python.class); -// ProcessBuilder pb = new ProcessBuilder(python, "-m", "pip", "install", "pyarrow==0.15.1"); -// pb.inheritIO().start().waitFor(); -// importPyArraow().del(); -// } -// } - public static PythonObject getPyArrowField(Field field) throws PythonException{ String name = field.name(); @@ -188,14 +210,80 @@ public static Schema getSchemaFromPythonObject(PythonObject pyArrowSchema) throw return new Schema(new FieldVector(fields)); } + + public static PythonObject getPyArrowTable(Table table) throws PythonException{ + PythonObject d = Python.dict(); + Schema schema =table.schema(); + Field[] fields = schema.fields().get(); + for (int i = 0; i < fields.length; i++){ + String colName = fields[i].name(); + PythonObject pyColName = new PythonObject(colName); + ChunkedArray chunkedArray = table.column(i); + INDArray arr = Nd4j.create(ByteDecoArrowSerde.fromArrowBuffer(chunkedArray.chunk(0).null_bitmap(), fields[i].type())); + PythonObject pyArr = new PythonObject(arr); + d.set(pyColName, pyArr); + } + PythonObject pyarrow = importPyArraow(); + PythonObject tableF = pyarrow.attr("table"); + PythonObject pyTable = tableF.call(d); + pyarrow.del(); + tableF.del(); + return pyTable; + } + + public static Table getTableFromPythonObject(PythonObject pyTable) throws PythonException{ + PythonObject pyarrow = importPyArraow(); + PythonObject tableType = pyarrow.attr("Table"); + if (!Python.isinstance(pyTable, tableType)){ + if (Python.isinstance(pyTable, Python.dictType())){ + PythonObject orig = pyTable; + PythonObject tableF = pyarrow.attr("table"); + pyTable = tableF.call(pyTable); + orig.del(); + tableF.del(); + } + else { + throw new PythonException("Expected pyarrow.lib.Table or dict, received " + Python.type(pyTable)); + } + } + PythonObject pySchema = pyTable.attr("schema"); + PythonObject pyShemaSize = Python.len(pySchema); + Field[] fields = new Field[pyShemaSize.toInt()]; + Array[] arrays = new FlatArray[fields.length]; + String origContext = Python.getCurrentContext(); + String tempContext = 'a' + UUID.randomUUID().toString().replace('-','_' ); + Python.setContext(tempContext); + for (int i = 0; i < fields.length; i++){ + Python.setVariable("col", pyTable.get(i)); + Python.exec("arr=col.to_pandas().to_numpy()"); + INDArray indArray = Python.getVariable("arr").toNumpy().getNd4jArray(); + fields[i] = getFieldFromPythonObject(pySchema.get(i)); + arrays[i] = new DoubleArray(new ArrowBuffer(new BytePointer(indArray.data().pointer()), indArray.data().length())); + } + + Python.setContext(origContext); + Python.deleteContext(tempContext); + FieldVector fieldVector = new FieldVector(fields); + Schema schema = new Schema(fieldVector); + ArrayVector arrayVector = new ArrayVector(arrays); + Table ret = Table.Make(schema, arrayVector); + pySchema.del(); + pyShemaSize.del(); + tableType.del(); + pyarrow.del(); + return ret; + + } + public static PythonObject getPyArrowTable(DataVecTable table) throws PythonException{ + PythonObject d = Python.dict(); for(int i = 0; i < table.numColumns(); i++){ PythonObject colName = new PythonObject(table.columnNameAt(i)); PythonObject colArr = new PythonObject(table.column(i).toNdArray()); d.set(colName, colArr); - //colName.del(); - //colArr.del(); + colName.del(); + colArr.del(); } PythonObject pyarrow = importPyArraow(); PythonObject tableF = pyarrow.attr("table"); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java new file mode 100644 index 000000000000..064933440244 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java @@ -0,0 +1,103 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.IOUtils; +import org.bytedeco.javacpp.Loader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +@Slf4j +public class PythonProcess { + private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + log.info("Executing command: " + Arrays.toString(allArgs)); + ProcessBuilder pb = new ProcessBuilder(allArgs); + Process process = pb.start(); + String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); + process.waitFor(); + return out; + + } + + public static void run(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + log.info("Executing command: " + Arrays.toString(allArgs)); + ProcessBuilder pb = new ProcessBuilder(allArgs); + pb.inheritIO().start().waitFor(); + } + public static void pipInstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "install", packageName); + }catch(Exception e){ + throw new PythonException("Error installing package " + packageName, e); + } + + } + + public static void pipInstall(String packageName, String version) throws PythonException{ + pipInstall(packageName + "==" + version); + } + + public static void pipUninstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "uninstall", packageName); + }catch(Exception e){ + throw new PythonException("Error uninstalling package " + packageName, e); + } + + } + public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{ + if (!gitRepoUrl.contains("://")){ + gitRepoUrl = "git://" + gitRepoUrl; + } + try{ + run("-m", "pip", "install", "git+", gitRepoUrl); + }catch(Exception e){ + throw new PythonException("Error installing package from " +gitRepoUrl, e); + } + + } + + public static String getPackageVersion(String packageName) throws IOException, InterruptedException, PythonException{ + String out = runAndReturn("-m", "pip", "show", packageName); + if (!out.contains("Version: ")){ + throw new PythonException("Can't find package " + packageName); + } + String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0]; + return pkgVersion; + } + + public static boolean isPackageInstalled(String packageName)throws IOException, InterruptedException{ + String out = runAndReturn("-m", "pip", "show", packageName); + return !out.isEmpty(); + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java index 170c6553507d..299ce3676cbe 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java @@ -19,18 +19,22 @@ import org.bytedeco.arrow.Field; import org.bytedeco.arrow.FieldVector; import org.bytedeco.arrow.Schema; +import org.bytedeco.arrow.Table; import org.datavec.api.transform.ColumnType; import org.datavec.arrow.table.DataVecTable; import org.datavec.arrow.table.column.DataVecColumn; import org.datavec.arrow.table.column.impl.*; import org.datavec.arrow.table.row.Row; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.bytedeco.arrow.global.arrow.*; import static org.junit.Assert.assertEquals; @@ -116,58 +120,58 @@ public void testTableConversion() throws PythonException{ } @Test - public void testTable() { - int count = 0; - ColumnType[] columnTypes = new ColumnType[] { - ColumnType.Integer, - ColumnType.Double, - ColumnType.Float, - ColumnType.Boolean, - ColumnType.String - }; - - DataVecColumn[] dataVecColumns = new DataVecColumn[columnTypes.length]; - DataVecColumn[] dataVecColumnsList = new DataVecColumn[columnTypes.length]; - - for(ColumnType columnType : columnTypes) { - switch(columnType) { - case Double: - dataVecColumns[count] = new DoubleColumn(columnType.name().toLowerCase(),new Double[]{1.0}); - dataVecColumnsList[count] = new DoubleColumn(columnType.name().toLowerCase(), Arrays.asList(1.0)); - break; - case Float: - dataVecColumns[count] = new FloatColumn(columnType.name().toLowerCase(),new Float[]{1.0f}); - dataVecColumnsList[count] = new FloatColumn(columnType.name().toLowerCase(),Arrays.asList(1.0f)); - break; - case Boolean: - dataVecColumns[count] = new BooleanColumn(columnType.name().toLowerCase(),new Boolean[]{true}); - dataVecColumnsList[count] = new BooleanColumn(columnType.name().toLowerCase(),Arrays.asList(true)); - break; - case String: - dataVecColumns[count] = new StringColumn(columnType.name().toLowerCase(),new String[]{"1.0"}); - dataVecColumnsList[count] = new StringColumn(columnType.name().toLowerCase(),Arrays.asList("1.0")); - break; - case Long: - dataVecColumns[count] = new LongColumn(columnType.name().toLowerCase(),new Long[]{1L}); - dataVecColumnsList[count] = new LongColumn(columnType.name().toLowerCase(),Arrays.asList(1L)); - break; - case Integer: - dataVecColumns[count] = new IntColumn(columnType.name().toUpperCase(),new Integer[]{1}); - dataVecColumnsList[count] = new IntColumn(columnType.name().toUpperCase(),Arrays.asList(1)); - break; - - } - - assertEquals(1,dataVecColumns[count].rows()); - assertEquals("Column type of " + columnType + " has wrong number of rows",1,dataVecColumns[count].rows()); - count++; - } - - DataVecTable dataVecTable1 = DataVecTable.create(dataVecColumns); - assertEquals(columnTypes.length,dataVecTable1.numColumns()); - DataVecTable dataVecTableList = DataVecTable.create(dataVecColumnsList); + public void testTables()throws Exception{ + PythonArrowUtils.init(); + Map map = new HashMap<>(); + map.put("a", Nd4j.zeros(5)); + map.put("b", Nd4j.ones(5)); + map.put("c", Nd4j.rand(5)); + PythonObject d = new PythonObject(map); + Table table = PythonArrowUtils.getTableFromPythonObject(d); } + @Test + public void testTable() throws Exception{ + PythonArrowUtils.init(); + ColumnType[] columnTypes = new ColumnType[] { + ColumnType.Integer, + ColumnType.Double, + ColumnType.Float, + ColumnType.Boolean, + ColumnType.String + }; + + DataVecColumn[] dataVecColumns = new DataVecColumn[columnTypes.length]; + for (int i = 0; i < columnTypes.length; i++){ + ColumnType columnType = columnTypes[i]; + switch(columnType) { + case Double: + dataVecColumns[i] = new DoubleColumn(columnType.name().toLowerCase(),new Double[]{1.0}); + break; + case Float: + dataVecColumns[i] = new FloatColumn(columnType.name().toLowerCase(),new Float[]{1.0f}); + break; + case Boolean: + dataVecColumns[i] = new BooleanColumn(columnType.name().toLowerCase(),new Boolean[]{true}); + break; + case String: + dataVecColumns[i] = new StringColumn(columnType.name().toLowerCase(),new String[]{"1.0"}); + break; + case Long: + dataVecColumns[i] = new LongColumn(columnType.name().toLowerCase(),new Long[]{1L}); + break; + case Integer: + dataVecColumns[i] = new IntColumn(columnType.name().toUpperCase(),new Integer[]{1}); + break; + + } + } + DataVecTable dataVecTable = DataVecTable.create(dataVecColumns); + PythonObject pyArrowTable = PythonArrowUtils.getPyArrowTable(dataVecTable); + + + } + } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java new file mode 100644 index 000000000000..7443b4dddc29 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java @@ -0,0 +1,24 @@ +package org.datavec.python; + +import org.junit.Assert; +import org.junit.Test; + +public class TestPythonProcess { + + @Test + public void testPythonProcess() throws Exception{ + String stdout = PythonProcess.runAndReturn("-m", "pip", "list"); + System.out.println(stdout); + Assert.assertTrue(stdout.replace(" ", "").contains("PackageVersion")); + } + @Test + public void testPackageVersion() throws Exception{ + System.out.println(PythonProcess.getPackageVersion("numpy")); + } + + @Test + public void testPackageInstalledCheck() throws Exception{ + Assert.assertFalse(PythonProcess.isPackageInstalled("abcdefgh")); + } + +} From 5888d7b0e2a95002a297790eeb1ae6a8d53b4e12 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Mon, 17 Feb 2020 02:14:46 +0400 Subject: [PATCH 08/10] test --- datavec/datavec-python/pom.xml | 23 +++-- .../datavec/python/TestPythonArrowUtils.java | 97 ++++++++++++++++++- libnd4j/pom.xml | 8 +- 3 files changed, 109 insertions(+), 19 deletions(-) diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 315aa94d0ed7..e0f5ccd437da 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -47,16 +47,6 @@ nd4j-arrow ${nd4j.version} - - org.bytedeco - arrow-platform - ${arrow.javacpp.version} - - - com.google.code.findbugs - jsr305 - 3.0.2 - org.datavec datavec-api @@ -88,4 +78,17 @@ test-nd4j-cuda-10.2 + + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + + + + diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java index 299ce3676cbe..887bc2d1b581 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java @@ -26,6 +26,8 @@ import org.datavec.arrow.table.column.impl.*; import org.datavec.arrow.table.row.Row; import org.junit.Assert; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; @@ -37,9 +39,6 @@ import java.util.Map; import static org.bytedeco.arrow.global.arrow.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - public class TestPythonArrowUtils { @Test @@ -120,7 +119,7 @@ public void testTableConversion() throws PythonException{ } @Test - public void testTables()throws Exception{ + public void testTableFromDict()throws Exception{ PythonArrowUtils.init(); Map map = new HashMap<>(); map.put("a", Nd4j.zeros(5)); @@ -133,7 +132,7 @@ public void testTables()throws Exception{ } @Test - public void testTable() throws Exception{ + public void testPyArrowTable() throws Exception{ PythonArrowUtils.init(); ColumnType[] columnTypes = new ColumnType[] { ColumnType.Integer, @@ -174,4 +173,92 @@ public void testTable() throws Exception{ } + @Test + public void testTable() { + PythonArrowUtils.init(); + int count = 0; + ColumnType[] columnTypes = new ColumnType[] { + ColumnType.Integer, + ColumnType.Double, + ColumnType.Float, + ColumnType.Boolean, + ColumnType.String + }; + + DataVecColumn[] dataVecColumns = new DataVecColumn[columnTypes.length]; + DataVecColumn[] dataVecColumnsList = new DataVecColumn[columnTypes.length]; + + for(ColumnType columnType : columnTypes) { + switch(columnType) { + case Double: + dataVecColumns[count] = new DoubleColumn(columnType.name().toLowerCase(),new Double[]{1.0}); + dataVecColumnsList[count] = new DoubleColumn(columnType.name().toLowerCase(), Arrays.asList(1.0)); + break; + case Float: + dataVecColumns[count] = new FloatColumn(columnType.name().toLowerCase(),new Float[]{1.0f}); + dataVecColumnsList[count] = new FloatColumn(columnType.name().toLowerCase(),Arrays.asList(1.0f)); + break; + case Boolean: + dataVecColumns[count] = new BooleanColumn(columnType.name().toLowerCase(),new Boolean[]{true}); + dataVecColumnsList[count] = new BooleanColumn(columnType.name().toLowerCase(),Arrays.asList(true)); + break; + case String: + dataVecColumns[count] = new StringColumn(columnType.name().toLowerCase(),new String[]{"1.0"}); + dataVecColumnsList[count] = new StringColumn(columnType.name().toLowerCase(),Arrays.asList("1.0")); + break; + case Long: + dataVecColumns[count] = new LongColumn(columnType.name().toLowerCase(),new Long[]{1L}); + dataVecColumnsList[count] = new LongColumn(columnType.name().toLowerCase(),Arrays.asList(1L)); + break; + case Integer: + dataVecColumns[count] = new IntColumn(columnType.name().toUpperCase(),new Integer[]{1}); + dataVecColumnsList[count] = new IntColumn(columnType.name().toUpperCase(),Arrays.asList(1)); + break; + + } + + assertEquals(1,dataVecColumns[count].rows()); + assertEquals("Column type of " + columnType + " has wrong number of rows",1,dataVecColumns[count].rows()); + count++; + } + + DataVecTable dataVecTable1 = DataVecTable.create(dataVecColumns); + assertEquals(columnTypes.length,dataVecTable1.numColumns()); + DataVecTable dataVecTableList = DataVecTable.create(dataVecColumnsList); + + Row row = dataVecTable1.row(0); + Row row2 = dataVecTableList.row(0); + assertEquals(1.0d, row.elementAtColumn("double"),1e-3); + assertEquals(1.0f, row.elementAtColumn("float"),1e-3f); + assertEquals("1.0",row.elementAtColumn("string")); + assertEquals(true, row.elementAtColumn("boolean")); + + assertEquals(1.0d, row2.elementAtColumn("double"),1e-3); + assertEquals(1.0f, row2.elementAtColumn("float"),1e-3f); + assertEquals("1.0",row2.elementAtColumn("string")); + assertEquals(true, row2.elementAtColumn("boolean")); + + + + for(int i = 0; i < row.columnNames().size(); i++) { + assertTrue(row.elementAtColumn(i).equals(row.elementAtColumn(row.columnNames().get(i)))); + assertTrue(dataVecTable1.column(i).contains(row.elementAtColumn(i))); + INDArray arr2 = dataVecTable1.column(i).toNdArray(); + assertEquals(dataVecTable1.column(1).rows(),arr2.length()); + List list = dataVecTable1.column(i).toList(); + assertEquals(1,list.size()); + + + assertTrue(row2.elementAtColumn(i).equals(row2.elementAtColumn(row2.columnNames().get(i)))); + assertTrue(dataVecTableList.column(i).contains(row2.elementAtColumn(i))); + arr2 = dataVecTableList.column(i).toNdArray(); + assertEquals(dataVecTableList.column(1).rows(),arr2.length()); + list = dataVecTableList.column(i).toList(); + assertEquals(1,list.size()); + + } + + assertEquals(1,dataVecTable1.numRows()); + } + } diff --git a/libnd4j/pom.xml b/libnd4j/pom.xml index 20b9d65623be..09de5a7d5894 100644 --- a/libnd4j/pom.xml +++ b/libnd4j/pom.xml @@ -64,7 +64,7 @@ - 10.2 + 10.1 7.6 release cpu @@ -150,7 +150,7 @@ ${libnd4j.cpu.compile.skip} - bash + C:\msys64\usr\bin\bash.exe ${project.basedir}/buildnativeoperations.sh --build-type ${libnd4j.build} @@ -183,7 +183,7 @@ ${libnd4j.test.skip} ${basedir}/tests_cpu - bash + C:\msys64\usr\bin\bash.exe run_tests.sh --chip ${libnd4j.chip} @@ -310,7 +310,7 @@ ${libnd4j.cuda.compile.skip} - bash + C:\msys64\usr\bin\bash.exe ${project.basedir}/buildnativeoperations.sh --build-type From 4c959cee9fae32074b2ff2a6eb146e2aff328951 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Mon, 17 Feb 2020 02:21:48 +0400 Subject: [PATCH 09/10] reverse commit --- libnd4j/pom.xml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/libnd4j/pom.xml b/libnd4j/pom.xml index 09de5a7d5894..e7c26f99a8a2 100644 --- a/libnd4j/pom.xml +++ b/libnd4j/pom.xml @@ -17,8 +17,8 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> org.deeplearning4j @@ -64,7 +64,7 @@ - 10.1 + 10.2 7.6 release cpu @@ -138,7 +138,7 @@ javacpp-cppbuild-validate validate - build + build @@ -150,7 +150,7 @@ ${libnd4j.cpu.compile.skip} - C:\msys64\usr\bin\bash.exe + bash ${project.basedir}/buildnativeoperations.sh --build-type ${libnd4j.build} @@ -183,7 +183,7 @@ ${libnd4j.test.skip} ${basedir}/tests_cpu - C:\msys64\usr\bin\bash.exe + bash run_tests.sh --chip ${libnd4j.chip} @@ -310,7 +310,7 @@ ${libnd4j.cuda.compile.skip} - C:\msys64\usr\bin\bash.exe + bash ${project.basedir}/buildnativeoperations.sh --build-type From 1c71d578247afe2ddd1b9a1367c21515c1dba54d Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Wed, 19 Feb 2020 16:54:17 +0400 Subject: [PATCH 10/10] rem init --- .../src/test/java/org/datavec/python/TestPythonArrowUtils.java | 1 - 1 file changed, 1 deletion(-) diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java index 887bc2d1b581..2bfce78695b7 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java @@ -175,7 +175,6 @@ public void testPyArrowTable() throws Exception{ @Test public void testTable() { - PythonArrowUtils.init(); int count = 0; ColumnType[] columnTypes = new ColumnType[] { ColumnType.Integer,