1 /*
2  Copyright (c) 2021 by Contributors
3 
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7 
8  http://www.apache.org/licenses/LICENSE-2.0
9 
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15  */
16 
17 package ml.dmlc.xgboost4j.gpu.java;
18 
19 import ai.rapids.cudf.BaseDeviceMemoryBuffer;
20 import ai.rapids.cudf.BufferType;
21 import ai.rapids.cudf.ColumnVector;
22 import ai.rapids.cudf.DType;
23 
24 import ml.dmlc.xgboost4j.java.Column;
25 
/**
 * Holds the base data of a cuDF ColumnVector (its device data and validity buffers, which
 * follow the Apache Arrow layout). It is used to generate the cuda array interface.
 */
30 class CudfColumn extends Column {
31 
32   private final long dataPtr; //  gpu data buffer address
33   private final long shape;   // row count
34   private final long validPtr; // gpu valid buffer address
35   private final int typeSize; // type size in bytes
36   private final String typeStr; // follow array interface spec
37   private final long nullCount; // null count
38 
39   private String arrayInterface = null; // the cuda array interface
40 
from(ColumnVector cv)41   public static CudfColumn from(ColumnVector cv) {
42     BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
43     BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
44     long validPtr = 0;
45     if (validBuffer != null) {
46       validPtr = validBuffer.getAddress();
47     }
48     DType dType = cv.getType();
49     String typeStr = "";
50     if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
51         dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
52         dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
53         dType == DType.TIMESTAMP_SECONDS) {
54       typeStr = "<f" + dType.getSizeInBytes();
55     } else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
56         dType == DType.INT32 || dType == DType.INT64) {
57       typeStr = "<i" + dType.getSizeInBytes();
58     } else {
59       // Unsupported type.
60       throw new IllegalArgumentException("Unsupported data type: " + dType);
61     }
62 
63     return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
64       dType.getSizeInBytes(), typeStr, cv.getNullCount());
65   }
66 
CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr, long nullCount)67   private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
68                     long nullCount) {
69     this.dataPtr = dataPtr;
70     this.shape = shape;
71     this.validPtr = validPtr;
72     this.typeSize = typeSize;
73     this.typeStr = typeStr;
74     this.nullCount = nullCount;
75   }
76 
77   @Override
getArrayInterfaceJson()78   public String getArrayInterfaceJson() {
79     // There is no race-condition
80     if (arrayInterface == null) {
81       arrayInterface = CudfUtils.buildArrayInterface(this);
82     }
83     return arrayInterface;
84   }
85 
getDataPtr()86   public long getDataPtr() {
87     return dataPtr;
88   }
89 
getShape()90   public long getShape() {
91     return shape;
92   }
93 
getValidPtr()94   public long getValidPtr() {
95     return validPtr;
96   }
97 
getTypeSize()98   public int getTypeSize() {
99     return typeSize;
100   }
101 
getTypeStr()102   public String getTypeStr() {
103     return typeStr;
104   }
105 
getNullCount()106   public long getNullCount() {
107     return nullCount;
108   }
109 
110 }
111