1 /* 2 Copyright (c) 2021 by Contributors 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package ml.dmlc.xgboost4j.gpu.java; 18 19 import ai.rapids.cudf.BaseDeviceMemoryBuffer; 20 import ai.rapids.cudf.BufferType; 21 import ai.rapids.cudf.ColumnVector; 22 import ai.rapids.cudf.DType; 23 24 import ml.dmlc.xgboost4j.java.Column; 25 26 /** 27 * This class is composing of base data with Apache Arrow format from Cudf ColumnVector. 28 * It will be used to generate the cuda array interface. 29 */ 30 class CudfColumn extends Column { 31 32 private final long dataPtr; // gpu data buffer address 33 private final long shape; // row count 34 private final long validPtr; // gpu valid buffer address 35 private final int typeSize; // type size in bytes 36 private final String typeStr; // follow array interface spec 37 private final long nullCount; // null count 38 39 private String arrayInterface = null; // the cuda array interface 40 from(ColumnVector cv)41 public static CudfColumn from(ColumnVector cv) { 42 BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA); 43 BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY); 44 long validPtr = 0; 45 if (validBuffer != null) { 46 validPtr = validBuffer.getAddress(); 47 } 48 DType dType = cv.getType(); 49 String typeStr = ""; 50 if (dType == DType.FLOAT32 || dType == DType.FLOAT64 || 51 dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS || 52 dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS || 53 dType == DType.TIMESTAMP_SECONDS) { 54 typeStr = "<f" + dType.getSizeInBytes(); 55 } else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 || 56 dType == DType.INT32 || dType == DType.INT64) { 57 typeStr = "<i" + dType.getSizeInBytes(); 58 } else { 59 // Unsupported type. 60 throw new IllegalArgumentException("Unsupported data type: " + dType); 61 } 62 63 return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr, 64 dType.getSizeInBytes(), typeStr, cv.getNullCount()); 65 } 66 CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr, long nullCount)67 private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr, 68 long nullCount) { 69 this.dataPtr = dataPtr; 70 this.shape = shape; 71 this.validPtr = validPtr; 72 this.typeSize = typeSize; 73 this.typeStr = typeStr; 74 this.nullCount = nullCount; 75 } 76 77 @Override getArrayInterfaceJson()78 public String getArrayInterfaceJson() { 79 // There is no race-condition 80 if (arrayInterface == null) { 81 arrayInterface = CudfUtils.buildArrayInterface(this); 82 } 83 return arrayInterface; 84 } 85 getDataPtr()86 public long getDataPtr() { 87 return dataPtr; 88 } 89 getShape()90 public long getShape() { 91 return shape; 92 } 93 getValidPtr()94 public long getValidPtr() { 95 return validPtr; 96 } 97 getTypeSize()98 public int getTypeSize() { 99 return typeSize; 100 } 101 getTypeStr()102 public String getTypeStr() { 103 return typeStr; 104 } 105 getNullCount()106 public long getNullCount() { 107 return nullCount; 108 } 109 110 } 111