1 /* 2 Copyright (c) 2021 by Contributors 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package ml.dmlc.xgboost4j.gpu.java; 18 19 import java.util.stream.IntStream; 20 21 import ai.rapids.cudf.Table; 22 23 import ml.dmlc.xgboost4j.java.ColumnBatch; 24 25 /** 26 * Class to wrap CUDF Table to generate the cuda array interface. 27 */ 28 public class CudfColumnBatch extends ColumnBatch { 29 private final Table feature; 30 private final Table label; 31 private final Table weight; 32 private final Table baseMargin; 33 CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins)34 public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) { 35 this.feature = feature; 36 this.label = labels; 37 this.weight = weights; 38 this.baseMargin = baseMargins; 39 } 40 41 @Override getFeatureArrayInterface()42 public String getFeatureArrayInterface() { 43 return getArrayInterface(this.feature); 44 } 45 46 @Override getLabelsArrayInterface()47 public String getLabelsArrayInterface() { 48 return getArrayInterface(this.label); 49 } 50 51 @Override getWeightsArrayInterface()52 public String getWeightsArrayInterface() { 53 return getArrayInterface(this.weight); 54 } 55 56 @Override getBaseMarginsArrayInterface()57 public String getBaseMarginsArrayInterface() { 58 return getArrayInterface(this.baseMargin); 59 } 60 61 @Override close()62 public void close() { 63 if (feature != null) feature.close(); 64 if (label != null) label.close(); 65 if (weight != null) weight.close(); 66 if (baseMargin != null) baseMargin.close(); 67 } 68 getArrayInterface(Table table)69 private String getArrayInterface(Table table) { 70 if (table == null || table.getNumberOfColumns() == 0) { 71 return ""; 72 } 73 return CudfUtils.buildArrayInterface(getAsCudfColumn(table)); 74 } 75 getAsCudfColumn(Table table)76 private CudfColumn[] getAsCudfColumn(Table table) { 77 if (table == null || table.getNumberOfColumns() == 0) { 78 // This will never happen. 79 return new CudfColumn[]{}; 80 } 81 82 return IntStream.range(0, table.getNumberOfColumns()) 83 .mapToObj((i) -> table.getColumn(i)) 84 .map(CudfColumn::from) 85 .toArray(CudfColumn[]::new); 86 } 87 88 } 89