diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml
index 55cf6c5dad1e..e0f5ccd437da 100644
--- a/datavec/datavec-python/pom.xml
+++ b/datavec/datavec-python/pom.xml
@@ -42,17 +42,21 @@
numpy-platform
${numpy.javacpp.version}
-
- com.google.code.findbugs
- jsr305
- 3.0.2
+ org.nd4j
+ nd4j-arrow
+ ${nd4j.version}
org.datavec
datavec-api
${project.version}
+
+ org.datavec
+ datavec-arrow
+ ${project.version}
+
ch.qos.logback
logback-classic
@@ -74,4 +78,17 @@
test-nd4j-cuda-10.2
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+ 1.8
+ 1.8
+
+
+
+
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java
index 9dabbef2d330..93b1c8622a21 100644
--- a/datavec/datavec-python/src/main/java/org/datavec/python/Python.java
+++ b/datavec/datavec-python/src/main/java/org/datavec/python/Python.java
@@ -112,12 +112,20 @@ public static PythonObject listType() {
return attr("list");
}
+    /**
+     * Builds a new Python {@code list} containing the given objects in array order.
+     *
+     * <p>{@code PyList_SetItem} takes over ("steals") the reference it is handed, while the
+     * {@link PythonObject} wrapper still refers to the same native object. The reference
+     * count is therefore incremented before each insertion so that the list owns its own
+     * reference and a later release of the wrapper cannot free an object the list still uses.
+     *
+     * @param pythonObjects elements to place in the list; must not be {@code null}
+     * @return a wrapper around the newly created list
+     */
+    public static PythonObject list(PythonObject[] pythonObjects) {
+        PyObject list = PyList_New(pythonObjects.length);
+        for (int i = 0; i < pythonObjects.length; i++) {
+            PyObject item = pythonObjects[i].getNativePythonObject();
+            // Give the list a reference of its own; SetItem consumes exactly one.
+            Py_IncRef(item);
+            PyList_SetItem(list, i, item);
+        }
+        return new PythonObject(list);
+    }
+
public static PythonObject dict(PythonObject pythonObject) {
return attr("dict").call(pythonObject);
}
public static PythonObject dict() {
- return attr("dict").call();
+ return new PythonObject(PyDict_New());
}
public static PythonObject dictType() {
@@ -172,6 +180,14 @@ public static PythonObject tuple() {
return attr("tuple").call();
}
+    /**
+     * Builds a new Python {@code tuple} containing the given objects in array order.
+     *
+     * <p>{@code PyTuple_SetItem} takes over ("steals") the reference it is handed, while the
+     * {@link PythonObject} wrapper still refers to the same native object. The reference
+     * count is therefore incremented before each insertion so that the tuple owns its own
+     * reference and a later release of the wrapper cannot free an object the tuple still uses.
+     *
+     * @param pythonObjects elements to place in the tuple; must not be {@code null}
+     * @return a wrapper around the newly created tuple
+     */
+    public static PythonObject tuple(PythonObject[] pythonObjects) {
+        PyObject tuple = PyTuple_New(pythonObjects.length);
+        for (int i = 0; i < pythonObjects.length; i++) {
+            PyObject item = pythonObjects[i].getNativePythonObject();
+            // Give the tuple a reference of its own; SetItem consumes exactly one.
+            Py_IncRef(item);
+            PyTuple_SetItem(tuple, i, item);
+        }
+        return new PythonObject(tuple);
+    }
+
public static PythonObject Exception(PythonObject pythonObject) {
return attr("Exception").call(pythonObject);
@@ -266,5 +282,11 @@ public static PythonGIL lock(){
return PythonGIL.lock();
}
+    /**
+     * Convenience pass-through to {@link PythonExecutioner#setVariable}.
+     *
+     * @param name  name to bind the value to
+     * @param value object to bind
+     * @throws PythonException propagated from the executioner
+     */
+    public static void setVariable(String name, PythonObject value) throws PythonException {
+        PythonExecutioner.setVariable(name, value);
+    }
+ public static PythonObject getVariable(String name){
+ return PythonExecutioner.getVariable(name);
+ }
}
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java
new file mode 100644
index 000000000000..83df7fcf4f77
--- /dev/null
+++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonArrowUtils.java
@@ -0,0 +1,296 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+package org.datavec.python;
+
+import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.arrow.*;
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.Loader;
+import org.datavec.arrow.table.DataVecTable;
+import org.nd4j.arrow.ByteDecoArrowSerde;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.UUID;
+
+import static org.bytedeco.arrow.global.arrow.*;
+
+@Slf4j
+public class PythonArrowUtils {
+
+ static {
+ init();
+ }
+ private static String PYARROW = "pyarrow";
+ private static String PANDAS = "pandas";
+ private static String REQUIRED_PYARROW_VERSION = "0.15.1"; // TODO get version from pom.xml?
+ private static boolean init = false;
+
+    /**
+     * Warm-up hook: wraps the data buffer of a throw-away random ND4J array in an Arrow
+     * {@code DoubleArray} and discards the result.
+     *
+     * <p>NOTE(review): the original author did not know why this call is required (see the
+     * TODO below); presumably it forces native libraries to load in a workable order —
+     * confirm before removing.
+     */
+    public static void init() {
+        // TODO: Find out why this works
+        INDArray warmUp = Nd4j.rand(5);
+        BytePointer rawData = new BytePointer(warmUp.data().pointer());
+        ArrowBuffer buffer = new ArrowBuffer(rawData, warmUp.data().length());
+        new DoubleArray(buffer);
+    }
+
+ public static PythonObject importPyArraow() throws PythonException{
+ try{
+ if (!PythonProcess.isPackageInstalled(PYARROW)){
+ log.info("PyArrow is not installed. Attempting to pip install pyarrow " + REQUIRED_PYARROW_VERSION);
+ PythonProcess.pipInstall(PYARROW, REQUIRED_PYARROW_VERSION);
+ PythonProcess.pipInstall(PANDAS);
+ }
+ else{
+ String pkgVersion = PythonProcess.getPackageVersion(PYARROW);
+ if (!pkgVersion.equals(REQUIRED_PYARROW_VERSION)) {
+ log.info("Required pyarrow version " + REQUIRED_PYARROW_VERSION + " but current version is " + pkgVersion + ". Attempting reinstall...");
+ PythonProcess.pipInstall(PYARROW, REQUIRED_PYARROW_VERSION);
+ PythonProcess.pipInstall(PANDAS);
+ }
+ }
+ } catch(Exception e){
+ throw new PythonException("Error verifying/installing pyarrow package.", e);
+ }
+
+ return Python.importModule(PYARROW);
+ }
+
+    /**
+     * Converts an ND4J array into a {@code pyarrow.Array} by wrapping it as a Python object
+     * and passing it to {@code pyarrow.array(...)}.
+     *
+     * @param arr array to convert
+     * @return the resulting pyarrow array; the caller is responsible for releasing it
+     * @throws PythonException if pyarrow cannot be imported or the conversion fails
+     */
+    public static PythonObject getPyArrowArrayFromINDArray(INDArray arr) throws PythonException {
+        PythonObject pyarrow = importPyArraow();
+        PythonObject npArr = null;
+        PythonObject arrayF = null;
+        try {
+            npArr = new PythonObject(arr);
+            arrayF = pyarrow.attr("array");
+            return arrayF.call(npArr);
+        } finally {
+            // Release the temporaries on success and on failure alike.
+            pyarrow.del();
+            if (npArr != null) {
+                npArr.del();
+            }
+            if (arrayF != null) {
+                arrayF.del();
+            }
+        }
+    }
+
+    /**
+     * Converts a {@code pyarrow.Array} into an ND4J array via its {@code to_numpy()} method.
+     *
+     * @param arr Python object expected to be an instance of {@code pyarrow.Array}
+     * @return the ND4J array obtained from the NumPy view of {@code arr}
+     * @throws PythonException if {@code arr} is not a {@code pyarrow.Array} or the conversion fails
+     */
+    public static INDArray getINDArrayFromPyArrowArray(PythonObject arr) throws PythonException {
+        PythonObject pyArrow = Python.importModule(PYARROW);
+        PythonObject arrayType = null;
+        PythonObject toNumpyF = null;
+        PythonObject npArr = null;
+        try {
+            arrayType = pyArrow.attr("Array");
+            if (!Python.isinstance(arr, arrayType)) {
+                throw new PythonException("Expected pyarrow.Array, received " + Python.type(arr));
+            }
+            toNumpyF = arr.attr("to_numpy");
+            npArr = toNumpyF.call();
+            return npArr.toNumpy().getNd4jArray();
+        } finally {
+            // Release every temporary that was actually created, whatever the outcome.
+            pyArrow.del();
+            if (arrayType != null) {
+                arrayType.del();
+            }
+            if (toNumpyF != null) {
+                toNumpyF.del();
+            }
+            if (npArr != null) {
+                npArr.del();
+            }
+        }
+    }
+
+    /**
+     * Creates the {@code pyarrow.field(name, type)} counterpart of an Arrow {@link Field}.
+     *
+     * @param field Arrow field to mirror on the Python side
+     * @return the pyarrow field object; the caller is responsible for releasing it
+     * @throws PythonException if the field type is {@code list} or pyarrow cannot be used
+     */
+    public static PythonObject getPyArrowField(Field field) throws PythonException {
+        String fieldName = field.name();
+        String factoryName = pyArrowTypeFactoryName(field.type().name());
+
+        PythonObject pyarrowModule = importPyArraow();
+        PythonObject fieldFactory = pyarrowModule.attr("field");
+        PythonObject typeFactory = pyarrowModule.attr(factoryName);
+        PythonObject pyType = typeFactory.call();
+        PythonObject result = fieldFactory.call(fieldName, pyType);
+
+        pyType.del();
+        typeFactory.del();
+        fieldFactory.del();
+        pyarrowModule.del();
+        return result;
+    }
+
+    /**
+     * Maps an Arrow type name to the name of the pyarrow function that creates that type.
+     * Names without a special case are passed through unchanged.
+     */
+    private static String pyArrowTypeFactoryName(String arrowTypeName) throws PythonException {
+        switch (arrowTypeName) {
+            case "list":
+                throw new PythonException("Unsupported field type: list");
+            case "bool":
+                return "bool_";
+            case "halffloat":
+                return "float16";
+            case "float":
+                return "float32";
+            case "double":
+                return "float64";
+            default:
+                return arrowTypeName;
+        }
+    }
+
+ public static Field getFieldFromPythonObject(PythonObject pyArrowField) throws PythonException{
+ PythonObject pyarrow = importPyArraow();
+ PythonObject fieldType = pyarrow.attr("Field");
+ if(!Python.isinstance(pyArrowField, fieldType)){
+ pyarrow.del();
+ fieldType.del();
+ throw new PythonException("Expected pyarrow.Field, received " + Python.type(pyArrowField));
+ }
+ PythonObject pyName = pyArrowField.attr("name");
+ String name = pyName.toString();
+ PythonObject pyTypeName = pyArrowField.attr("type");
+ String typeName = pyTypeName.toString();
+ DataType dt;
+ switch (typeName){
+ case "bool":
+ dt = _boolean();
+ break;
+ case "halffloat":
+ dt = float16();
+ break;
+ case "float":
+ dt = float32();
+ break;
+ case "double":
+ dt = float64();
+ break;
+ default:
+ try{
+ dt = (DataType)org.bytedeco.arrow.global.arrow.class.getMethod(typeName).invoke(null);
+ }
+ catch (Exception e){
+ throw new PythonException("Unsupported type: " + typeName, e);
+ }
+ }
+ Field ret = new Field(name, dt);
+ pyarrow.del();
+ fieldType.del();
+ pyName.del();
+ pyTypeName.del();
+ return ret;
+ }
+
+    /**
+     * Creates the {@code pyarrow.schema([...])} counterpart of an Arrow {@link Schema} by
+     * converting each field with {@link #getPyArrowField(Field)}.
+     *
+     * @param schema Arrow schema to mirror on the Python side
+     * @return the pyarrow schema object; the caller is responsible for releasing it
+     * @throws PythonException if a field cannot be converted or pyarrow cannot be used
+     */
+    public static PythonObject getPyArrowSchema(Schema schema) throws PythonException {
+        Field[] fields = schema.fields().get();
+        PythonObject[] pyFields = new PythonObject[fields.length];
+        for (int i = 0; i < fields.length; i++) {
+            pyFields[i] = getPyArrowField(fields[i]);
+        }
+        PythonObject pyarrow = importPyArraow();
+        PythonObject schemaF = null;
+        PythonObject pyFieldList = null;
+        try {
+            schemaF = pyarrow.attr("schema");
+            pyFieldList = Python.list(pyFields);
+            return schemaF.call(pyFieldList);
+        } finally {
+            pyarrow.del();
+            if (schemaF != null) {
+                schemaF.del();
+            }
+            // The temporary list is only needed for the call above.
+            if (pyFieldList != null) {
+                pyFieldList.del();
+            }
+            // NOTE(review): the per-field wrappers are left untouched on purpose; whether
+            // Python.list shares or takes over their references decides who may release
+            // them — confirm ownership before adding del() calls for pyFields.
+        }
+    }
+
+    /**
+     * Converts a {@code pyarrow.Schema} into an Arrow {@link Schema}, field by field.
+     *
+     * @param pyArrowSchema Python object expected to be an instance of {@code pyarrow.Schema}
+     * @return the equivalent Arrow schema
+     * @throws PythonException if the object is not a {@code pyarrow.Schema} or one of its
+     *                         fields cannot be converted
+     */
+    public static Schema getSchemaFromPythonObject(PythonObject pyArrowSchema) throws PythonException {
+        PythonObject pyarrow = importPyArraow();
+        PythonObject schemaType = null;
+        PythonObject pySize = null;
+        try {
+            schemaType = pyarrow.attr("Schema");
+            if (!Python.isinstance(pyArrowSchema, schemaType)) {
+                throw new PythonException("Expected pyarrow.Schema, received " + Python.type(pyArrowSchema));
+            }
+            pySize = Python.len(pyArrowSchema);
+            int size = pySize.toInt();
+            Field[] fields = new Field[size];
+            for (int i = 0; i < size; i++) {
+                PythonObject pyField = pyArrowSchema.get(i);
+                try {
+                    fields[i] = getFieldFromPythonObject(pyField);
+                } finally {
+                    // The element wrapper was created by get(i) above and is not used again.
+                    pyField.del();
+                }
+            }
+            return new Schema(new FieldVector(fields));
+        } finally {
+            if (pySize != null) {
+                pySize.del();
+            }
+            if (schemaType != null) {
+                schemaType.del();
+            }
+            pyarrow.del();
+        }
+    }
+
+
+    /**
+     * Builds a {@code pyarrow.table(...)} from an Arrow {@link Table} by placing one array
+     * per column into a Python dict keyed by column name.
+     *
+     * @param table Arrow table to mirror on the Python side
+     * @return the pyarrow table object; the caller is responsible for releasing it
+     * @throws PythonException if a column cannot be converted or pyarrow cannot be used
+     */
+    public static PythonObject getPyArrowTable(Table table) throws PythonException {
+        PythonObject d = Python.dict();
+        Schema schema = table.schema();
+        Field[] fields = schema.fields().get();
+        for (int i = 0; i < fields.length; i++) {
+            PythonObject pyColName = new PythonObject(fields[i].name());
+            PythonObject pyArr = null;
+            try {
+                ChunkedArray chunkedArray = table.column(i);
+                // NOTE(review): only chunk 0 is read, and the buffer handed to the decoder is
+                // null_bitmap(); presumably the values buffer was intended and multi-chunk
+                // columns should be handled — confirm against ByteDecoArrowSerde.
+                INDArray arr = Nd4j.create(
+                        ByteDecoArrowSerde.fromArrowBuffer(chunkedArray.chunk(0).null_bitmap(), fields[i].type()));
+                pyArr = new PythonObject(arr);
+                d.set(pyColName, pyArr);
+            } finally {
+                // Same clean-up as the DataVecTable overload: the wrappers are only needed
+                // for the set(...) call.
+                pyColName.del();
+                if (pyArr != null) {
+                    pyArr.del();
+                }
+            }
+        }
+        PythonObject pyarrow = importPyArraow();
+        PythonObject tableF = null;
+        try {
+            tableF = pyarrow.attr("table");
+            return tableF.call(d);
+        } finally {
+            pyarrow.del();
+            if (tableF != null) {
+                tableF.del();
+            }
+        }
+    }
+
+ public static Table getTableFromPythonObject(PythonObject pyTable) throws PythonException{
+ PythonObject pyarrow = importPyArraow();
+ PythonObject tableType = pyarrow.attr("Table");
+ if (!Python.isinstance(pyTable, tableType)){
+ if (Python.isinstance(pyTable, Python.dictType())){
+ PythonObject orig = pyTable;
+ PythonObject tableF = pyarrow.attr("table");
+ pyTable = tableF.call(pyTable);
+ orig.del();
+ tableF.del();
+ }
+ else {
+ throw new PythonException("Expected pyarrow.lib.Table or dict, received " + Python.type(pyTable));
+ }
+ }
+ PythonObject pySchema = pyTable.attr("schema");
+ PythonObject pyShemaSize = Python.len(pySchema);
+ Field[] fields = new Field[pyShemaSize.toInt()];
+ Array[] arrays = new FlatArray[fields.length];
+ String origContext = Python.getCurrentContext();
+ String tempContext = 'a' + UUID.randomUUID().toString().replace('-','_' );
+ Python.setContext(tempContext);
+ for (int i = 0; i < fields.length; i++){
+ Python.setVariable("col", pyTable.get(i));
+ Python.exec("arr=col.to_pandas().to_numpy()");
+ INDArray indArray = Python.getVariable("arr").toNumpy().getNd4jArray();
+ fields[i] = getFieldFromPythonObject(pySchema.get(i));
+ arrays[i] = new DoubleArray(new ArrowBuffer(new BytePointer(indArray.data().pointer()), indArray.data().length()));
+ }
+
+ Python.setContext(origContext);
+ Python.deleteContext(tempContext);
+ FieldVector fieldVector = new FieldVector(fields);
+ Schema schema = new Schema(fieldVector);
+ ArrayVector arrayVector = new ArrayVector(arrays);
+ Table ret = Table.Make(schema, arrayVector);
+ pySchema.del();
+ pyShemaSize.del();
+ tableType.del();
+ pyarrow.del();
+ return ret;
+
+ }
+
+    /**
+     * Builds a {@code pyarrow.table(...)} from a {@link DataVecTable}: each column's
+     * {@code toNdArray()} result is stored in a Python dict under the column name, and the
+     * dict is handed to pyarrow.
+     *
+     * @param table DataVec table to mirror on the Python side
+     * @return the pyarrow table object; the caller is responsible for releasing it
+     * @throws PythonException if pyarrow cannot be used or the table cannot be built
+     */
+    public static PythonObject getPyArrowTable(DataVecTable table) throws PythonException {
+        PythonObject columnsByName = Python.dict();
+        for (int col = 0; col < table.numColumns(); col++) {
+            PythonObject pyName = new PythonObject(table.columnNameAt(col));
+            PythonObject pyValues = new PythonObject(table.column(col).toNdArray());
+            columnsByName.set(pyName, pyValues);
+            // The wrappers are only needed for the insertion above.
+            pyName.del();
+            pyValues.del();
+        }
+
+        PythonObject pyarrowModule = importPyArraow();
+        PythonObject tableFactory = pyarrowModule.attr("table");
+        PythonObject result = tableFactory.call(columnsByName);
+        pyarrowModule.del();
+        tableFactory.del();
+        return result;
+    }
+
+}
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
index e2d2e57477da..6359d444a9a3 100644
--- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
+++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
@@ -20,6 +20,8 @@
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
+import org.bytedeco.cpython.global.python;
+import org.bytedeco.javacpp.Loader;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.io.ClassPathResource;
@@ -340,6 +342,18 @@ public static synchronized void initPythonPath() {
if (path == null) {
log.info("Setting python default path");
File[] packages = numpy.cachePackages();
+
+ //// TODO: fix in javacpp
+ File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
+ File[] packages2 = new File[packages.length + 1];
+ for (int i = 0;i < packages.length; i++){
+ System.out.println(packages[i].getAbsolutePath());
+ packages2[i] = packages[i];
+ }
+ packages2[packages.length] = sitePackagesWindows;
+ System.out.println(sitePackagesWindows.getAbsolutePath());
+ packages = packages2;
+ //////////
Py_SetPath(packages);
} else {
log.info("Setting python path " + path);
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java
new file mode 100644
index 000000000000..064933440244
--- /dev/null
+++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java
@@ -0,0 +1,103 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+package org.datavec.python;
+
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.io.IOUtils;
+import org.bytedeco.javacpp.Loader;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+
+@Slf4j
+public class PythonProcess {
+    /** Result of loading the bundled cpython {@code python} class; used as the program to launch. */
+    private static final String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
+
+    /**
+     * Runs the Python executable with the given arguments and returns everything it wrote
+     * to standard output, decoded as UTF-8.
+     *
+     * <p>Standard error is forwarded to this process's standard error instead of being left
+     * in an unread pipe; an unread pipe can fill up and stall the child while this method
+     * is still waiting on standard output.
+     *
+     * @param arguments command-line arguments placed after the executable
+     * @return the child's standard output
+     * @throws IOException          if the process cannot be started or its output cannot be read
+     * @throws InterruptedException if the wait for the process is interrupted
+     */
+    public static String runAndReturn(String... arguments) throws IOException, InterruptedException {
+        String[] allArgs = new String[arguments.length + 1];
+        allArgs[0] = pythonExecutable;
+        System.arraycopy(arguments, 0, allArgs, 1, arguments.length);
+        log.info("Executing command: " + Arrays.toString(allArgs));
+        ProcessBuilder pb = new ProcessBuilder(allArgs);
+        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
+        Process process = pb.start();
+        String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
+        process.waitFor();
+        return out;
+    }
+
+    /**
+     * Runs the Python executable with the given arguments, sharing this process's standard
+     * streams, and waits for it to finish.
+     *
+     * <p>A non-zero exit code is reported as an {@link IOException} so that callers such as
+     * the pip helpers do not mistake a failed command for a successful one.
+     *
+     * @param arguments command-line arguments placed after the executable
+     * @throws IOException          if the process cannot be started or exits with a non-zero code
+     * @throws InterruptedException if the wait for the process is interrupted
+     */
+    public static void run(String... arguments) throws IOException, InterruptedException {
+        String[] allArgs = new String[arguments.length + 1];
+        allArgs[0] = pythonExecutable;
+        System.arraycopy(arguments, 0, allArgs, 1, arguments.length);
+        log.info("Executing command: " + Arrays.toString(allArgs));
+        ProcessBuilder pb = new ProcessBuilder(allArgs);
+        int exitCode = pb.inheritIO().start().waitFor();
+        if (exitCode != 0) {
+            throw new IOException("Command " + Arrays.toString(allArgs) + " exited with code " + exitCode);
+        }
+    }
+    /**
+     * Installs a package with {@code python -m pip install <packageName>}.
+     *
+     * @param packageName requirement string passed to pip as-is
+     * @throws PythonException if the command cannot be run; the cause is preserved
+     */
+    public static void pipInstall(String packageName) throws PythonException {
+        try {
+            run("-m", "pip", "install", packageName);
+        } catch (InterruptedException e) {
+            // Keep the interruption visible to the caller's thread instead of swallowing it.
+            Thread.currentThread().interrupt();
+            throw new PythonException("Error installing package " + packageName, e);
+        } catch (Exception e) {
+            throw new PythonException("Error installing package " + packageName, e);
+        }
+    }
+
+ public static void pipInstall(String packageName, String version) throws PythonException{
+ pipInstall(packageName + "==" + version);
+ }
+
+    /**
+     * Removes a package with {@code python -m pip uninstall -y <packageName>}.
+     *
+     * <p>{@code -y} answers pip's confirmation prompt; without it the child process, which
+     * shares this process's standard input, would wait for someone to type an answer.
+     *
+     * @param packageName name of the package to remove
+     * @throws PythonException if the command cannot be run; the cause is preserved
+     */
+    public static void pipUninstall(String packageName) throws PythonException {
+        try {
+            run("-m", "pip", "uninstall", "-y", packageName);
+        } catch (InterruptedException e) {
+            // Keep the interruption visible to the caller's thread instead of swallowing it.
+            Thread.currentThread().interrupt();
+            throw new PythonException("Error uninstalling package " + packageName, e);
+        } catch (Exception e) {
+            throw new PythonException("Error uninstalling package " + packageName, e);
+        }
+    }
+ public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{
+ if (!gitRepoUrl.contains("://")){
+ gitRepoUrl = "git://" + gitRepoUrl;
+ }
+ try{
+ run("-m", "pip", "install", "git+", gitRepoUrl);
+ }catch(Exception e){
+ throw new PythonException("Error installing package from " +gitRepoUrl, e);
+ }
+
+ }
+
+    /**
+     * Reads the installed version of a package from the output of
+     * {@code python -m pip show <packageName>}.
+     *
+     * <p>Only a line that starts with {@code "Version: "} is accepted, so other keys that
+     * merely end in that text are not mistaken for it, and lines are split on any line
+     * terminator so that the child's line-ending style does not matter.
+     *
+     * @param packageName name of the package to query
+     * @return the text after {@code "Version: "}, trimmed
+     * @throws IOException          if pip cannot be run
+     * @throws InterruptedException if the wait for pip is interrupted
+     * @throws PythonException      if the output contains no version line
+     */
+    public static String getPackageVersion(String packageName) throws IOException, InterruptedException, PythonException {
+        final String versionKey = "Version: ";
+        String out = runAndReturn("-m", "pip", "show", packageName);
+        for (String line : out.split("\\R")) {
+            if (line.startsWith(versionKey)) {
+                return line.substring(versionKey.length()).trim();
+            }
+        }
+        throw new PythonException("Can't find package " + packageName);
+    }
+
+ public static boolean isPackageInstalled(String packageName)throws IOException, InterruptedException{
+ String out = runAndReturn("-m", "pip", "show", packageName);
+ return !out.isEmpty();
+ }
+
+}
diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java
new file mode 100644
index 000000000000..2bfce78695b7
--- /dev/null
+++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonArrowUtils.java
@@ -0,0 +1,263 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.datavec.python;
+
+import org.bytedeco.arrow.Field;
+import org.bytedeco.arrow.FieldVector;
+import org.bytedeco.arrow.Schema;
+import org.bytedeco.arrow.Table;
+import org.datavec.api.transform.ColumnType;
+import org.datavec.arrow.table.DataVecTable;
+import org.datavec.arrow.table.column.DataVecColumn;
+import org.datavec.arrow.table.column.impl.*;
+import org.datavec.arrow.table.row.Row;
+import org.junit.Assert;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import org.junit.Before;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.bytedeco.arrow.global.arrow.*;
+public class TestPythonArrowUtils {
+
+    /** Smoke test: importing pyarrow through the helper must complete without throwing. */
+    @Test
+    public void testPyArrowImport() throws Exception {
+        PythonArrowUtils.importPyArraow();
+    }
+
+ @Test
+ public void testINDArrayConversion() throws PythonException{
+ INDArray array = Nd4j.rand(10);
+ PythonObject pyArrowArray = PythonArrowUtils.getPyArrowArrayFromINDArray(array);
+ INDArray array2 = PythonArrowUtils.getINDArrayFromPyArrowArray(pyArrowArray);
+ Assert.assertEquals(array, array2);
+
+ // test no copy
+ Assert.assertEquals(array.data().address(), array2.data().address());
+ array.putScalar(0, Nd4j.rand(1).getDouble(0));
+ Assert.assertEquals(array, array2);
+ }
+
+    /**
+     * Converts one field of each listed Arrow type to pyarrow and back, and checks that name
+     * and type survive the round trip. The type check compares the original field with the
+     * converted one, in the same way as {@code testSchemaConversion}.
+     */
+    @Test
+    public void testFieldConversion() throws PythonException {
+        Field[] fields = new Field[]{
+                new Field("a", int8()),
+                new Field("b", int16()),
+                new Field("c", int32()),
+                new Field("d", int64()),
+                new Field("e", uint8()),
+                new Field("f", uint16()),
+                new Field("g", uint32()),
+                new Field("h", uint64()),
+                new Field("i", _boolean()),
+                new Field("j", float16()),
+                new Field("k", float32()),
+                new Field("l", float64()),
+                new Field("m", binary()),
+        };
+        for (Field field : fields) {
+            PythonObject pyArrowField = PythonArrowUtils.getPyArrowField(field);
+            Field field2 = PythonArrowUtils.getFieldFromPythonObject(pyArrowField);
+            Assert.assertEquals(field.name(), field2.name());
+            Assert.assertEquals(field.type(), field2.type());
+        }
+    }
+
+    /**
+     * Converts a schema with one field of each listed Arrow type to pyarrow and back, and
+     * checks field count, names and types.
+     */
+    @Test
+    public void testSchemaConversion() throws PythonException {
+        Field[] expected = new Field[]{
+                new Field("a", int8()),
+                new Field("b", int16()),
+                new Field("c", int32()),
+                new Field("d", int64()),
+                new Field("e", uint8()),
+                new Field("f", uint16()),
+                new Field("g", uint32()),
+                new Field("h", uint64()),
+                new Field("i", _boolean()),
+                new Field("j", float16()),
+                new Field("k", float32()),
+                new Field("l", float64()),
+                new Field("m", binary()),
+        };
+        Schema original = new Schema(new FieldVector(expected));
+        PythonObject asPyArrow = PythonArrowUtils.getPyArrowSchema(original);
+        Schema roundTripped = PythonArrowUtils.getSchemaFromPythonObject(asPyArrow);
+
+        Field[] actual = roundTripped.fields().get();
+        Assert.assertEquals(expected.length, actual.length);
+        int index = 0;
+        while (index < expected.length) {
+            Assert.assertEquals(expected[index].name(), actual[index].name());
+            Assert.assertEquals(expected[index].type(), actual[index].type());
+            index++;
+        }
+    }
+
+ @Test
+ public void testTableConversion() throws PythonException{
+ new DoubleColumn("double", new Double[]{1.0});
+
+ }
+
+    /**
+     * Builds a Python dict of three length-5 arrays and checks that it can be converted to
+     * an Arrow table without throwing.
+     */
+    @Test
+    public void testTableFromDict() throws Exception {
+        PythonArrowUtils.init();
+
+        Map columns = new HashMap<>();
+        columns.put("a", Nd4j.zeros(5));
+        columns.put("b", Nd4j.ones(5));
+        columns.put("c", Nd4j.rand(5));
+        PythonObject pyDict = new PythonObject(columns);
+
+        Table converted = PythonArrowUtils.getTableFromPythonObject(pyDict);
+    }
+
+    /**
+     * Builds a one-row DataVecTable with one column per listed column type and checks that
+     * it can be converted to a pyarrow table without throwing.
+     */
+    @Test
+    public void testPyArrowTable() throws Exception {
+        PythonArrowUtils.init();
+        ColumnType[] types = new ColumnType[]{
+                ColumnType.Integer,
+                ColumnType.Double,
+                ColumnType.Float,
+                ColumnType.Boolean,
+                ColumnType.String
+        };
+
+        DataVecColumn[] columns = new DataVecColumn[types.length];
+        int slot = 0;
+        for (ColumnType type : types) {
+            // Every column is named after its type in lower case, except the integer
+            // column, which uses upper case.
+            String lowerName = type.name().toLowerCase();
+            switch (type) {
+                case Double:
+                    columns[slot] = new DoubleColumn(lowerName, new Double[]{1.0});
+                    break;
+                case Float:
+                    columns[slot] = new FloatColumn(lowerName, new Float[]{1.0f});
+                    break;
+                case Boolean:
+                    columns[slot] = new BooleanColumn(lowerName, new Boolean[]{true});
+                    break;
+                case String:
+                    columns[slot] = new StringColumn(lowerName, new String[]{"1.0"});
+                    break;
+                case Long:
+                    columns[slot] = new LongColumn(lowerName, new Long[]{1L});
+                    break;
+                case Integer:
+                    columns[slot] = new IntColumn(type.name().toUpperCase(), new Integer[]{1});
+                    break;
+
+            }
+            slot++;
+        }
+        DataVecTable source = DataVecTable.create(columns);
+        PythonObject converted = PythonArrowUtils.getPyArrowTable(source);
+    }
+
+ /**
+ * Builds two one-row DataVecTables with the same columns, one from arrays and one from
+ * lists, and checks row access by name and by index, column membership, NDArray
+ * conversion and list conversion on both. No Python code is involved in this test.
+ */
+ @Test
+ public void testTable() {
+ int count = 0;
+ ColumnType[] columnTypes = new ColumnType[] {
+ ColumnType.Integer,
+ ColumnType.Double,
+ ColumnType.Float,
+ ColumnType.Boolean,
+ ColumnType.String
+ };
+
+ // dataVecColumns is filled from arrays, dataVecColumnsList from lists with the same values.
+ DataVecColumn[] dataVecColumns = new DataVecColumn[columnTypes.length];
+ DataVecColumn[] dataVecColumnsList = new DataVecColumn[columnTypes.length];
+
+ for(ColumnType columnType : columnTypes) {
+ // Columns are named after the lower-cased type name, except Integer which is upper-cased.
+ switch(columnType) {
+ case Double:
+ dataVecColumns[count] = new DoubleColumn(columnType.name().toLowerCase(),new Double[]{1.0});
+ dataVecColumnsList[count] = new DoubleColumn(columnType.name().toLowerCase(), Arrays.asList(1.0));
+ break;
+ case Float:
+ dataVecColumns[count] = new FloatColumn(columnType.name().toLowerCase(),new Float[]{1.0f});
+ dataVecColumnsList[count] = new FloatColumn(columnType.name().toLowerCase(),Arrays.asList(1.0f));
+ break;
+ case Boolean:
+ dataVecColumns[count] = new BooleanColumn(columnType.name().toLowerCase(),new Boolean[]{true});
+ dataVecColumnsList[count] = new BooleanColumn(columnType.name().toLowerCase(),Arrays.asList(true));
+ break;
+ case String:
+ dataVecColumns[count] = new StringColumn(columnType.name().toLowerCase(),new String[]{"1.0"});
+ dataVecColumnsList[count] = new StringColumn(columnType.name().toLowerCase(),Arrays.asList("1.0"));
+ break;
+ // Long is not in columnTypes above, so this case is not exercised by the loop.
+ case Long:
+ dataVecColumns[count] = new LongColumn(columnType.name().toLowerCase(),new Long[]{1L});
+ dataVecColumnsList[count] = new LongColumn(columnType.name().toLowerCase(),Arrays.asList(1L));
+ break;
+ case Integer:
+ dataVecColumns[count] = new IntColumn(columnType.name().toUpperCase(),new Integer[]{1});
+ dataVecColumnsList[count] = new IntColumn(columnType.name().toUpperCase(),Arrays.asList(1));
+ break;
+
+ }
+
+ // NOTE(review): the second assertion repeats the first and only adds a failure message.
+ assertEquals(1,dataVecColumns[count].rows());
+ assertEquals("Column type of " + columnType + " has wrong number of rows",1,dataVecColumns[count].rows());
+ count++;
+ }
+
+ DataVecTable dataVecTable1 = DataVecTable.create(dataVecColumns);
+ assertEquals(columnTypes.length,dataVecTable1.numColumns());
+ DataVecTable dataVecTableList = DataVecTable.create(dataVecColumnsList);
+
+ // Check the single row of each table by column name.
+ Row row = dataVecTable1.row(0);
+ Row row2 = dataVecTableList.row(0);
+ assertEquals(1.0d, row.elementAtColumn("double"),1e-3);
+ assertEquals(1.0f, row.elementAtColumn("float"),1e-3f);
+ assertEquals("1.0",row.elementAtColumn("string"));
+ assertEquals(true, row.elementAtColumn("boolean"));
+
+ assertEquals(1.0d, row2.elementAtColumn("double"),1e-3);
+ assertEquals(1.0f, row2.elementAtColumn("float"),1e-3f);
+ assertEquals("1.0",row2.elementAtColumn("string"));
+ assertEquals(true, row2.elementAtColumn("boolean"));
+
+
+
+ // For every column: index-based and name-based access must agree, the column must
+ // contain the row's value, and both conversions must yield one element.
+ for(int i = 0; i < row.columnNames().size(); i++) {
+ assertTrue(row.elementAtColumn(i).equals(row.elementAtColumn(row.columnNames().get(i))));
+ assertTrue(dataVecTable1.column(i).contains(row.elementAtColumn(i)));
+ INDArray arr2 = dataVecTable1.column(i).toNdArray();
+ // NOTE(review): the expected length is read from column(1) on every iteration;
+ // presumably column(i) was intended — confirm.
+ assertEquals(dataVecTable1.column(1).rows(),arr2.length());
+ List list = dataVecTable1.column(i).toList();
+ assertEquals(1,list.size());
+
+
+ assertTrue(row2.elementAtColumn(i).equals(row2.elementAtColumn(row2.columnNames().get(i))));
+ assertTrue(dataVecTableList.column(i).contains(row2.elementAtColumn(i)));
+ arr2 = dataVecTableList.column(i).toNdArray();
+ assertEquals(dataVecTableList.column(1).rows(),arr2.length());
+ list = dataVecTableList.column(i).toList();
+ assertEquals(1,list.size());
+
+ }
+
+ assertEquals(1,dataVecTable1.numRows());
+ }
+
+}
diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java
new file mode 100644
index 000000000000..7443b4dddc29
--- /dev/null
+++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonProcess.java
@@ -0,0 +1,24 @@
+package org.datavec.python;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestPythonProcess {
+
+ @Test
+ public void testPythonProcess() throws Exception{
+ String stdout = PythonProcess.runAndReturn("-m", "pip", "list");
+ System.out.println(stdout);
+ Assert.assertTrue(stdout.replace(" ", "").contains("PackageVersion"));
+ }
+ @Test
+ public void testPackageVersion() throws Exception{
+ System.out.println(PythonProcess.getPackageVersion("numpy"));
+ }
+
+ @Test
+ public void testPackageInstalledCheck() throws Exception{
+ Assert.assertFalse(PythonProcess.isPackageInstalled("abcdefgh"));
+ }
+
+}
diff --git a/libnd4j/pom.xml b/libnd4j/pom.xml
index 20b9d65623be..e7c26f99a8a2 100644
--- a/libnd4j/pom.xml
+++ b/libnd4j/pom.xml
@@ -17,8 +17,8 @@
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
org.deeplearning4j
@@ -138,7 +138,7 @@
javacpp-cppbuild-validate
validate
- build
+ build