1 /*
2  Copyright (c) 2021 by Contributors
3 
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7 
8  http://www.apache.org/licenses/LICENSE-2.0
9 
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15  */
16 
17 package ml.dmlc.xgboost4j.gpu.java;
18 
19 import java.util.stream.IntStream;
20 
21 import ai.rapids.cudf.Table;
22 
23 import ml.dmlc.xgboost4j.java.ColumnBatch;
24 
25 /**
26  * Class to wrap CUDF Table to generate the cuda array interface.
27  */
28 public class CudfColumnBatch extends ColumnBatch {
29   private final Table feature;
30   private final Table label;
31   private final Table weight;
32   private final Table baseMargin;
33 
CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins)34   public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
35     this.feature = feature;
36     this.label = labels;
37     this.weight = weights;
38     this.baseMargin = baseMargins;
39   }
40 
41   @Override
getFeatureArrayInterface()42   public String getFeatureArrayInterface() {
43     return getArrayInterface(this.feature);
44   }
45 
46   @Override
getLabelsArrayInterface()47   public String getLabelsArrayInterface() {
48     return getArrayInterface(this.label);
49   }
50 
51   @Override
getWeightsArrayInterface()52   public String getWeightsArrayInterface() {
53     return getArrayInterface(this.weight);
54   }
55 
56   @Override
getBaseMarginsArrayInterface()57   public String getBaseMarginsArrayInterface() {
58     return getArrayInterface(this.baseMargin);
59   }
60 
61   @Override
close()62   public void close() {
63     if (feature != null) feature.close();
64     if (label != null) label.close();
65     if (weight != null) weight.close();
66     if (baseMargin != null) baseMargin.close();
67   }
68 
getArrayInterface(Table table)69   private String getArrayInterface(Table table) {
70     if (table == null || table.getNumberOfColumns() == 0) {
71       return "";
72     }
73     return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
74   }
75 
getAsCudfColumn(Table table)76   private CudfColumn[] getAsCudfColumn(Table table) {
77     if (table == null || table.getNumberOfColumns() == 0) {
78       // This will never happen.
79       return new CudfColumn[]{};
80     }
81 
82     return IntStream.range(0, table.getNumberOfColumns())
83       .mapToObj((i) -> table.getColumn(i))
84       .map(CudfColumn::from)
85       .toArray(CudfColumn[]::new);
86   }
87 
88 }
89