diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 55cf6c5dad1e..e0f5ccd437da 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -42,17 +42,21 @@ numpy-platform ${numpy.javacpp.version} - - com.google.code.findbugs - jsr305 - 3.0.2 + org.nd4j + nd4j-arrow + ${nd4j.version} org.datavec datavec-api ${project.version} + + org.datavec + datavec-arrow + ${project.version} + ch.qos.logback logback-classic @@ -74,4 +78,17 @@ test-nd4j-cuda-10.2 + + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + + + + diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java index 9dabbef2d330..93b1c8622a21 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java @@ -112,12 +112,20 @@ public static PythonObject listType() { return attr("list"); } + /** Builds a new Python list whose i-th element is the native object wrapped by {@code pythonObjects[i]}. */ public static PythonObject list(PythonObject[] pythonObjects){ + PyObject list = PyList_New(pythonObjects.length); + for(int i = 0;i < pythonObjects.length; i++){ + /* NOTE(review): the CPython docs state that PyList_SetItem steals a reference to the item; confirm callers do not release the wrapped element as well. */ PyList_SetItem(list, i, pythonObjects[i].getNativePythonObject()); + } + return new PythonObject(list); + } + public static PythonObject dict(PythonObject pythonObject) { return attr("dict").call(pythonObject); } public static PythonObject dict() { - return attr("dict").call(); + /* Creates the empty dict through the C API (PyDict_New) rather than by calling the Python-level dict builtin. */ return new PythonObject(PyDict_New()); } public static PythonObject dictType() { @@ -172,6 +180,14 @@ public static PythonObject tuple() { return attr("tuple").call(); } + /** Builds a new Python tuple whose i-th element is the native object wrapped by {@code pythonObjects[i]}. */ public static PythonObject tuple(PythonObject[] pythonObjects){ + PyObject tuple = PyTuple_New(pythonObjects.length); + for(int i = 0;i < pythonObjects.length; i++){ + /* NOTE(review): the CPython docs state that PyTuple_SetItem steals a reference to the item; confirm callers do not release the wrapped element as well. */ PyTuple_SetItem(tuple, i, pythonObjects[i].getNativePythonObject()); + } + return new PythonObject(tuple); + } + public static PythonObject Exception(PythonObject pythonObject) { return 
attr("Exception").call(pythonObject); @@ -266,5 +282,11 @@ public static PythonGIL lock(){ return PythonGIL.lock(); } + public static void setVariable(String name, PythonObject value) throws PythonException{ + PythonExecutioner.setVariable(name, value); + } + public static PythonObject getVariable(String name){ + return PythonExecutioner.getVariable(name); + } } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java new file mode 100644 index 000000000000..83df7fcf4f77 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java @@ -0,0 +1,323 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.bytedeco.arrow.*; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Loader; +import org.datavec.arrow.table.DataVecTable; +import org.nd4j.arrow.ByteDecoArrowSerde; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.UUID; + +import static org.bytedeco.arrow.global.arrow.*; + +/** Converts between ND4J / bytedeco Arrow objects and their pyarrow counterparts in the embedded Python interpreter. */ +@Slf4j +public class PythonArrowUtils { + + static { + init(); + } + private static String PYARROW = "pyarrow"; + private static String PANDAS = "pandas"; + private static String REQUIRED_PYARROW_VERSION = "0.15.1"; // TODO get version from pom.xml? + private static boolean pyArrowVerified = false; // true once the pyarrow installation has been checked + + /** Creates a throw-away Arrow DoubleArray over an ND4J buffer; run from the static initializer (the TODO below records that the reason it is needed is unknown). */ + public static void init(){ + // TODO: Find out why this works + INDArray dummyArr = Nd4j.rand(5); + new DoubleArray(new ArrowBuffer(new BytePointer(dummyArr.data().pointer()), dummyArr.data().length())); + } + + /** Imports the pyarrow module; the first call checks via pip that pyarrow REQUIRED_PYARROW_VERSION is present and (re)installs it together with pandas if not. */ + public static PythonObject importPyArraow() throws PythonException{ + if (pyArrowVerified){ + return Python.importModule(PYARROW); + } + try{ + if (!PythonProcess.isPackageInstalled(PYARROW)){ + log.info("PyArrow is not installed. Attempting to pip install pyarrow " + REQUIRED_PYARROW_VERSION); + PythonProcess.pipInstall(PYARROW, REQUIRED_PYARROW_VERSION); + PythonProcess.pipInstall(PANDAS); + } + else{ + String pkgVersion = PythonProcess.getPackageVersion(PYARROW); + if (!pkgVersion.equals(REQUIRED_PYARROW_VERSION)) { + log.info("Required pyarrow version " + REQUIRED_PYARROW_VERSION + " but current version is " + pkgVersion + ". 
Attempting reinstall..."); + PythonProcess.pipInstall(PYARROW, REQUIRED_PYARROW_VERSION); + PythonProcess.pipInstall(PANDAS); + } + } + } catch(Exception e){ + throw new PythonException("Error verifying/installing pyarrow package.", e); + } + // Remember the successful check so later calls do not spawn pip subprocesses again. + pyArrowVerified = true; + + return Python.importModule(PYARROW); + } + + /** Wraps {@code arr} as a numpy array and passes it to pyarrow.array(...); the returned pyarrow object is not released here. */ + public static PythonObject getPyArrowArrayFromINDArray(INDArray arr) throws PythonException{ + PythonObject pyarrow = importPyArraow(); + PythonObject npArr = new PythonObject(arr); + PythonObject arrayF = pyarrow.attr("array"); + PythonObject ret = arrayF.call(npArr); + pyarrow.del(); + npArr.del(); + arrayF.del(); + return ret; + } + + /** Converts a pyarrow.Array to an INDArray via its to_numpy() method; throws PythonException for any other Python type. */ + public static INDArray getINDArrayFromPyArrowArray(PythonObject arr) throws PythonException{ + PythonObject pyArrow = Python.importModule(PYARROW); + PythonObject arrayType = pyArrow.attr("Array"); + if (!Python.isinstance(arr, arrayType)){ + pyArrow.del(); + arrayType.del(); + throw new PythonException("Expected pyarrow.Array, received " + Python.type(arr)); + } + PythonObject toNumpyF = arr.attr("to_numpy"); + PythonObject npArr = toNumpyF.call(); + pyArrow.del(); + arrayType.del(); + toNumpyF.del(); + INDArray ret = npArr.toNumpy().getNd4jArray(); + npArr.del(); + return ret; + } + + /** Builds a pyarrow field named like {@code field}, mapping Arrow type names to pyarrow factory names (bool to bool_, halffloat to float16, float to float32, double to float64); list is rejected. */ + public static PythonObject getPyArrowField(Field field) throws PythonException{ + String name = field.name(); + org.bytedeco.arrow.DataType type = field.type(); + String typeName = type.name(); + String pyTypeFName; + switch (typeName){ + case "list": + throw new PythonException("Unsupported field type: list"); + case "bool": + pyTypeFName = "bool_"; + break; + case "halffloat": + pyTypeFName = "float16"; + break; + case "float": + pyTypeFName = "float32"; + break; + case "double": + pyTypeFName = "float64"; + break; + default: + pyTypeFName = typeName; + } + PythonObject pyarrow = importPyArraow(); + PythonObject fieldF = pyarrow.attr("field"); + PythonObject pyArrowTypeF = pyarrow.attr(pyTypeFName); + PythonObject pyArrowType = pyArrowTypeF.call(); + PythonObject 
pyArrowField = fieldF.call(name, pyArrowType); + pyArrowType.del(); + pyArrowTypeF.del(); + fieldF.del(); + pyarrow.del(); + return pyArrowField; + } + + /** Rebuilds an Arrow Field from a pyarrow.Field, resolving the type by name (names without an explicit case are looked up reflectively as no-arg factories on the arrow global class). */ + public static Field getFieldFromPythonObject(PythonObject pyArrowField) throws PythonException{ + PythonObject pyarrow = importPyArraow(); + PythonObject fieldType = pyarrow.attr("Field"); + if(!Python.isinstance(pyArrowField, fieldType)){ + pyarrow.del(); + fieldType.del(); + throw new PythonException("Expected pyarrow.Field, received " + Python.type(pyArrowField)); + } + PythonObject pyName = pyArrowField.attr("name"); + String name = pyName.toString(); + PythonObject pyTypeName = pyArrowField.attr("type"); + String typeName = pyTypeName.toString(); + // Release the Python handles before the type lookup, which may throw. + pyarrow.del(); + fieldType.del(); + pyName.del(); + pyTypeName.del(); + DataType dt; + switch (typeName){ + case "bool": + dt = _boolean(); + break; + case "halffloat": + dt = float16(); + break; + case "float": + dt = float32(); + break; + case "double": + dt = float64(); + break; + default: + try{ + dt = (DataType)org.bytedeco.arrow.global.arrow.class.getMethod(typeName).invoke(null); + } + catch (Exception e){ + throw new PythonException("Unsupported type: " + typeName, e); + } + } + return new Field(name, dt); + } + + /** Converts every field of {@code schema} with getPyArrowField and passes the resulting list to pyarrow.schema(...). */ + public static PythonObject getPyArrowSchema(Schema schema) throws PythonException{ + Field[] fields = schema.fields().get(); + PythonObject[] pyFields = new PythonObject[fields.length]; + for (int i = 0; i < fields.length; i++){ + pyFields[i] = getPyArrowField(fields[i]); + } + PythonObject pyarrow = importPyArraow(); + PythonObject schemaF = pyarrow.attr("schema"); + PythonObject pySchema = schemaF.call(Python.list(pyFields)); + pyarrow.del(); + schemaF.del(); + return pySchema; + } + + /** Rebuilds an Arrow Schema from a pyarrow.Schema by converting its fields one by one; throws PythonException for any other Python type. */ + public static Schema getSchemaFromPythonObject(PythonObject pyArrowSchema) throws PythonException{ + PythonObject pyarrow = importPyArraow(); + PythonObject schemaType = pyarrow.attr("Schema"); + if(!Python.isinstance(pyArrowSchema, 
schemaType)){ + pyarrow.del(); + schemaType.del(); + throw new PythonException("Expected pyarrow.Schema, received " + Python.type(pyArrowSchema)); + } + PythonObject pySize = Python.len(pyArrowSchema); + int size = pySize.toInt(); + Field[] fields = new Field[size]; + for(int i = 0; i < size; i++){ + PythonObject pyField = pyArrowSchema.get(i); + fields[i] = getFieldFromPythonObject(pyField); + } + pySize.del(); + schemaType.del(); + pyarrow.del(); + return new Schema(new FieldVector(fields)); + } + + + /** Builds a pyarrow table from a dict of column name to numpy array, using the first chunk of each column of {@code table}. */ + public static PythonObject getPyArrowTable(Table table) throws PythonException{ + PythonObject d = Python.dict(); + Schema schema =table.schema(); + Field[] fields = schema.fields().get(); + for (int i = 0; i < fields.length; i++){ + String colName = fields[i].name(); + PythonObject pyColName = new PythonObject(colName); + ChunkedArray chunkedArray = table.column(i); + // NOTE(review): this reads chunk(0).null_bitmap(), which looks like the validity bitmap rather than the value buffer, and ignores later chunks - confirm intent. + INDArray arr = Nd4j.create(ByteDecoArrowSerde.fromArrowBuffer(chunkedArray.chunk(0).null_bitmap(), fields[i].type())); + PythonObject pyArr = new PythonObject(arr); + d.set(pyColName, pyArr); + pyColName.del(); + pyArr.del(); + } + PythonObject pyarrow = importPyArraow(); + PythonObject tableF = pyarrow.attr("table"); + PythonObject pyTable = tableF.call(d); + pyarrow.del(); + tableF.del(); + d.del(); + return pyTable; + } + + /** Converts a pyarrow Table (or a dict, which is first turned into a table and then released with del()) into an Arrow Table; every column is wrapped as a DoubleArray. */ + public static Table getTableFromPythonObject(PythonObject pyTable) throws PythonException{ + PythonObject pyarrow = importPyArraow(); + PythonObject tableType = pyarrow.attr("Table"); + if (!Python.isinstance(pyTable, tableType)){ + if (Python.isinstance(pyTable, Python.dictType())){ + PythonObject orig = pyTable; + PythonObject tableF = pyarrow.attr("table"); + pyTable = tableF.call(pyTable); + orig.del(); + tableF.del(); + } + else { + pyarrow.del(); + tableType.del(); + throw new PythonException("Expected pyarrow.lib.Table or dict, received " + Python.type(pyTable)); + } + } + PythonObject pySchema = pyTable.attr("schema"); + PythonObject pyShemaSize = Python.len(pySchema); + Field[] fields = new Field[pyShemaSize.toInt()]; + Array[] arrays = new 
FlatArray[fields.length]; + String origContext = Python.getCurrentContext(); + String tempContext = 'a' + UUID.randomUUID().toString().replace('-','_' ); + Python.setContext(tempContext); + try { + for (int i = 0; i < fields.length; i++){ + Python.setVariable("col", pyTable.get(i)); + Python.exec("arr=col.to_pandas().to_numpy()"); + INDArray indArray = Python.getVariable("arr").toNumpy().getNd4jArray(); + fields[i] = getFieldFromPythonObject(pySchema.get(i)); + arrays[i] = new DoubleArray(new ArrowBuffer(new BytePointer(indArray.data().pointer()), indArray.data().length())); + } + } finally { + // Restore the caller's context even when a column fails to convert. + Python.setContext(origContext); + Python.deleteContext(tempContext); + } + FieldVector fieldVector = new FieldVector(fields); + Schema schema = new Schema(fieldVector); + ArrayVector arrayVector = new ArrayVector(arrays); + Table ret = Table.Make(schema, arrayVector); + pySchema.del(); + pyShemaSize.del(); + tableType.del(); + pyarrow.del(); + return ret; + + } + + /** Builds a pyarrow table from a dict mapping each column name of {@code table} to its toNdArray() contents. */ + public static PythonObject getPyArrowTable(DataVecTable table) throws PythonException{ + + PythonObject d = Python.dict(); + for(int i = 0; i < table.numColumns(); i++){ + PythonObject colName = new PythonObject(table.columnNameAt(i)); + PythonObject colArr = new PythonObject(table.column(i).toNdArray()); + d.set(colName, colArr); + colName.del(); + colArr.del(); + } + PythonObject pyarrow = importPyArraow(); + PythonObject tableF = pyarrow.attr("table"); + PythonObject pyTable = tableF.call(d); + pyarrow.del(); + tableF.del(); + d.del(); + return pyTable; + } + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index e2d2e57477da..6359d444a9a3 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -20,6 +20,8 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; +import 
org.bytedeco.cpython.global.python; +import org.bytedeco.javacpp.Loader; import org.bytedeco.numpy.global.numpy; import org.nd4j.linalg.io.ClassPathResource; @@ -340,6 +342,18 @@ public static synchronized void initPythonPath() { if (path == null) { log.info("Setting python default path"); File[] packages = numpy.cachePackages(); + + //// TODO: fix in javacpp + File sitePackagesWindows = new File(python.cachePackage(), "site-packages"); + File[] packages2 = new File[packages.length + 1]; + for (int i = 0;i < packages.length; i++){ + System.out.println(packages[i].getAbsolutePath()); + packages2[i] = packages[i]; + } + packages2[packages.length] = sitePackagesWindows; + System.out.println(sitePackagesWindows.getAbsolutePath()); + packages = packages2; + ////////// Py_SetPath(packages); } else { log.info("Setting python path " + path); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java new file mode 100644 index 000000000000..064933440244 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java @@ -0,0 +1,112 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.IOUtils; +import org.bytedeco.javacpp.Loader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +/** Runs the Python executable resolved through the bytedeco cpython preset as a child process; mainly used to drive pip. */ +@Slf4j +public class PythonProcess { + private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + /** Runs python with {@code arguments}, waits for it to exit and returns its standard output decoded as UTF-8; standard error is forwarded to this process so the child cannot block on an unread pipe. */ + public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + log.info("Executing command: " + Arrays.toString(allArgs)); + ProcessBuilder pb = new ProcessBuilder(allArgs).redirectError(ProcessBuilder.Redirect.INHERIT); + Process process = pb.start(); + String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); + process.waitFor(); + return out; + + } + + /** Runs python with {@code arguments}, sharing the standard streams of this process, and waits for it to exit; the exit code is not checked. */ + public static void run(String... 
arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + log.info("Executing command: " + Arrays.toString(allArgs)); + ProcessBuilder pb = new ProcessBuilder(allArgs); + pb.inheritIO().start().waitFor(); + } + /** Installs {@code packageName} (any pip requirement specifier) with pip. */ + public static void pipInstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "install", packageName); + }catch(Exception e){ + throw new PythonException("Error installing package " + packageName, e); + } + + } + + /** Installs exactly {@code version} of {@code packageName} with pip. */ + public static void pipInstall(String packageName, String version) throws PythonException{ + pipInstall(packageName + "==" + version); + } + + /** Uninstalls {@code packageName} with pip, passing -y so pip does not wait for interactive confirmation. */ + public static void pipUninstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "uninstall", "-y", packageName); + }catch(Exception e){ + throw new PythonException("Error uninstalling package " + packageName, e); + } + + } + /** Installs a package from a git repository with pip; git:// is assumed when {@code gitRepoUrl} has no scheme. */ + public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{ + if (!gitRepoUrl.contains("://")){ + gitRepoUrl = "git://" + gitRepoUrl; + } + try{ + run("-m", "pip", "install", "git+" + gitRepoUrl); + }catch(Exception e){ + throw new PythonException("Error installing package from " +gitRepoUrl, e); + } + + } + + /** Returns the value of the Version field printed by pip show; throws PythonException when the output has no such field. */ + public static String getPackageVersion(String packageName) throws IOException, InterruptedException, PythonException{ + String out = runAndReturn("-m", "pip", "show", packageName); + if (!out.contains("Version: ")){ + throw new PythonException("Can't find package " + packageName); + } + String pkgVersion = out.split("Version: ")[1].split("\\R")[0].trim(); + return pkgVersion; + } + + /** Returns true when pip show prints anything for {@code packageName} on standard output. */ + public static boolean isPackageInstalled(String packageName)throws IOException, InterruptedException{ + String out = runAndReturn("-m", "pip", "show", packageName); + return !out.isEmpty(); + } + +} diff --git 
a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java new file mode 100644 index 000000000000..2bfce78695b7 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java @@ -0,0 +1,264 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + +import org.bytedeco.arrow.Field; +import org.bytedeco.arrow.FieldVector; +import org.bytedeco.arrow.Schema; +import org.bytedeco.arrow.Table; +import org.datavec.api.transform.ColumnType; +import org.datavec.arrow.table.DataVecTable; +import org.datavec.arrow.table.column.DataVecColumn; +import org.datavec.arrow.table.column.impl.*; +import org.datavec.arrow.table.row.Row; +import org.junit.Assert; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import org.junit.Before; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.bytedeco.arrow.global.arrow.*; +/** Round-trip tests for the PythonArrowUtils conversions between Arrow / ND4J objects and pyarrow. */ +public class TestPythonArrowUtils { + + @Test + public void 
testPyArrowImport() throws Exception{ + PythonArrowUtils.importPyArraow(); + } + + @Test + public void testINDArrayConversion() throws PythonException{ + INDArray array = Nd4j.rand(10); + PythonObject pyArrowArray = PythonArrowUtils.getPyArrowArrayFromINDArray(array); + INDArray array2 = PythonArrowUtils.getINDArrayFromPyArrowArray(pyArrowArray); + Assert.assertEquals(array, array2); + + // test no copy + Assert.assertEquals(array.data().address(), array2.data().address()); + array.putScalar(0, Nd4j.rand(1).getDouble(0)); + Assert.assertEquals(array, array2); + } + + @Test + public void testFieldConversion() throws PythonException{ + Field[] fields = new Field[]{ + new Field("a", int8()), + new Field("b", int16()), + new Field("c", int32()), + new Field("d", int64()), + new Field("e", uint8()), + new Field("f", uint16()), + new Field("g", uint32()), + new Field("h", uint64()), + new Field("i", _boolean()), + new Field("j", float16()), + new Field("k", float32()), + new Field("l", float64()), + new Field("m", binary()), + }; + for (Field field: fields){ + PythonObject pyArrowField = PythonArrowUtils.getPyArrowField(field); + Field field2 = PythonArrowUtils.getFieldFromPythonObject(pyArrowField); + Assert.assertEquals(field.name(), field2.name()); + Assert.assertEquals(field.type(), field2.type()); // compare against the round-tripped field, not the original with itself + } + } + + @Test + public void testSchemaConversion() throws PythonException{ + Field[] fields = new Field[]{ + new Field("a", int8()), + new Field("b", int16()), + new Field("c", int32()), + new Field("d", int64()), + new Field("e", uint8()), + new Field("f", uint16()), + new Field("g", uint32()), + new Field("h", uint64()), + new Field("i", _boolean()), + new Field("j", float16()), + new Field("k", float32()), + new Field("l", float64()), + new Field("m", binary()), + }; + Schema schema = new Schema(new FieldVector(fields)); + PythonObject pySchema = PythonArrowUtils.getPyArrowSchema(schema); + Schema schema2 = PythonArrowUtils.getSchemaFromPythonObject(pySchema); + 
Field[] fields2 = schema2.fields().get(); + Assert.assertEquals(fields.length, fields2.length); + for(int i = 0;i < fields.length; i++){ + Assert.assertEquals(fields[i].name(), fields2[i].name()); + Assert.assertEquals(fields[i].type(), fields2[i].type()); + } + } + + @Test + public void testTableConversion() throws PythonException{ + new DoubleColumn("double", new Double[]{1.0}); + + } + + @Test + public void testTableFromDict()throws Exception{ + PythonArrowUtils.init(); + Map map = new HashMap<>(); + map.put("a", Nd4j.zeros(5)); + map.put("b", Nd4j.ones(5)); + map.put("c", Nd4j.rand(5)); + + PythonObject d = new PythonObject(map); + + Table table = PythonArrowUtils.getTableFromPythonObject(d); + } + + @Test + public void testPyArrowTable() throws Exception{ + PythonArrowUtils.init(); + ColumnType[] columnTypes = new ColumnType[] { + ColumnType.Integer, + ColumnType.Double, + ColumnType.Float, + ColumnType.Boolean, + ColumnType.String + }; + + DataVecColumn[] dataVecColumns = new DataVecColumn[columnTypes.length]; + for (int i = 0; i < columnTypes.length; i++){ + ColumnType columnType = columnTypes[i]; + switch(columnType) { + case Double: + dataVecColumns[i] = new DoubleColumn(columnType.name().toLowerCase(),new Double[]{1.0}); + break; + case Float: + dataVecColumns[i] = new FloatColumn(columnType.name().toLowerCase(),new Float[]{1.0f}); + break; + case Boolean: + dataVecColumns[i] = new BooleanColumn(columnType.name().toLowerCase(),new Boolean[]{true}); + break; + case String: + dataVecColumns[i] = new StringColumn(columnType.name().toLowerCase(),new String[]{"1.0"}); + break; + case Long: + dataVecColumns[i] = new LongColumn(columnType.name().toLowerCase(),new Long[]{1L}); + break; + case Integer: + dataVecColumns[i] = new IntColumn(columnType.name().toUpperCase(),new Integer[]{1}); + break; + + } + } + DataVecTable dataVecTable = DataVecTable.create(dataVecColumns); + PythonObject pyArrowTable = PythonArrowUtils.getPyArrowTable(dataVecTable); + + + } + + 
@Test + public void testTable() { + int count = 0; + ColumnType[] columnTypes = new ColumnType[] { + ColumnType.Integer, + ColumnType.Double, + ColumnType.Float, + ColumnType.Boolean, + ColumnType.String + }; + + DataVecColumn[] dataVecColumns = new DataVecColumn[columnTypes.length]; + DataVecColumn[] dataVecColumnsList = new DataVecColumn[columnTypes.length]; + + for(ColumnType columnType : columnTypes) { + switch(columnType) { + case Double: + dataVecColumns[count] = new DoubleColumn(columnType.name().toLowerCase(),new Double[]{1.0}); + dataVecColumnsList[count] = new DoubleColumn(columnType.name().toLowerCase(), Arrays.asList(1.0)); + break; + case Float: + dataVecColumns[count] = new FloatColumn(columnType.name().toLowerCase(),new Float[]{1.0f}); + dataVecColumnsList[count] = new FloatColumn(columnType.name().toLowerCase(),Arrays.asList(1.0f)); + break; + case Boolean: + dataVecColumns[count] = new BooleanColumn(columnType.name().toLowerCase(),new Boolean[]{true}); + dataVecColumnsList[count] = new BooleanColumn(columnType.name().toLowerCase(),Arrays.asList(true)); + break; + case String: + dataVecColumns[count] = new StringColumn(columnType.name().toLowerCase(),new String[]{"1.0"}); + dataVecColumnsList[count] = new StringColumn(columnType.name().toLowerCase(),Arrays.asList("1.0")); + break; + case Long: + dataVecColumns[count] = new LongColumn(columnType.name().toLowerCase(),new Long[]{1L}); + dataVecColumnsList[count] = new LongColumn(columnType.name().toLowerCase(),Arrays.asList(1L)); + break; + case Integer: + dataVecColumns[count] = new IntColumn(columnType.name().toUpperCase(),new Integer[]{1}); + dataVecColumnsList[count] = new IntColumn(columnType.name().toUpperCase(),Arrays.asList(1)); + break; + + } + + assertEquals(1,dataVecColumns[count].rows()); + assertEquals("Column type of " + columnType + " has wrong number of rows",1,dataVecColumns[count].rows()); + count++; + } + + DataVecTable dataVecTable1 = DataVecTable.create(dataVecColumns); + 
assertEquals(columnTypes.length,dataVecTable1.numColumns()); + DataVecTable dataVecTableList = DataVecTable.create(dataVecColumnsList); + + Row row = dataVecTable1.row(0); + Row row2 = dataVecTableList.row(0); + assertEquals(1.0d, row.elementAtColumn("double"),1e-3); + assertEquals(1.0f, row.elementAtColumn("float"),1e-3f); + assertEquals("1.0",row.elementAtColumn("string")); + assertEquals(true, row.elementAtColumn("boolean")); + + assertEquals(1.0d, row2.elementAtColumn("double"),1e-3); + assertEquals(1.0f, row2.elementAtColumn("float"),1e-3f); + assertEquals("1.0",row2.elementAtColumn("string")); + assertEquals(true, row2.elementAtColumn("boolean")); + + + + for(int i = 0; i < row.columnNames().size(); i++) { + assertTrue(row.elementAtColumn(i).equals(row.elementAtColumn(row.columnNames().get(i)))); + assertTrue(dataVecTable1.column(i).contains(row.elementAtColumn(i))); + INDArray arr2 = dataVecTable1.column(i).toNdArray(); + assertEquals(dataVecTable1.column(1).rows(),arr2.length()); + List list = dataVecTable1.column(i).toList(); + assertEquals(1,list.size()); + + + assertTrue(row2.elementAtColumn(i).equals(row2.elementAtColumn(row2.columnNames().get(i)))); + assertTrue(dataVecTableList.column(i).contains(row2.elementAtColumn(i))); + arr2 = dataVecTableList.column(i).toNdArray(); + assertEquals(dataVecTableList.column(1).rows(),arr2.length()); + list = dataVecTableList.column(i).toList(); + assertEquals(1,list.size()); + + } + + assertEquals(1,dataVecTable1.numRows()); + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java new file mode 100644 index 000000000000..7443b4dddc29 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java @@ -0,0 +1,24 @@ +package org.datavec.python; + +import org.junit.Assert; +import org.junit.Test; + +public class TestPythonProcess { + + @Test + public void 
testPythonProcess() throws Exception{ + String stdout = PythonProcess.runAndReturn("-m", "pip", "list"); + System.out.println(stdout); + Assert.assertTrue(stdout.replace(" ", "").contains("PackageVersion")); + } + @Test + public void testPackageVersion() throws Exception{ + System.out.println(PythonProcess.getPackageVersion("numpy")); + } + + @Test + public void testPackageInstalledCheck() throws Exception{ + Assert.assertFalse(PythonProcess.isPackageInstalled("abcdefgh")); + } + +} diff --git a/libnd4j/pom.xml b/libnd4j/pom.xml index 20b9d65623be..e7c26f99a8a2 100644 --- a/libnd4j/pom.xml +++ b/libnd4j/pom.xml @@ -17,8 +17,8 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> org.deeplearning4j @@ -138,7 +138,7 @@ javacpp-cppbuild-validate validate - build + build