1 /*
2
3 Copyright (C) 2019 Alois Schloegl <alois.schloegl@ist.ac.at>
4 This file is part of the "BioSig for C/C++" repository
5 (biosig4c++) at http://biosig.sf.net/
6
7 This program is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public License
9 as published by the Free Software Foundation; either version 3
10 of the License, or (at your option) any later version.
11
12
13 References:
14 https://stackoverflow.com/questions/44378764/hello-tensorflow-using-the-c-api
15 https://stackoverflow.com/questions/41688217/how-to-load-a-graph-with-tensorflow-so-and-c-api-h-in-c-language
16 https://tebesu.github.io/posts/Training-a-TensorFlow-graph-in-C++-API
17 */
18
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <tensorflow/c/c_api.h>
24 #include "mex.h"
25 //#include "matrix.h"
26
27
28 #ifdef tmwtypes_h
29 #if (MX_API_VER<=0x07020000)
30 typedef int mwSize;
31 #endif
32 #endif
33
34
35 TF_Buffer* read_file(const char* file);
36
/*
 * Deallocator installed on the TF_Buffer returned by read_file(): releases
 * the malloc()ed payload. The length argument is only required by the
 * TF_Buffer callback signature.
 */
void free_buffer(void* data, size_t length) {
	(void) length;	/* not needed to release the memory */
	free(data);
}
40
read_file(const char * file)41 TF_Buffer* read_file(const char* file) {
42 FILE *f = fopen(file, "rb");
43 if (f==NULL) return NULL;
44 fseek(f, 0, SEEK_END);
45 long fsize = ftell(f);
46 fseek(f, 0, SEEK_SET); //same as rewind(f);
47
48 void* data = malloc(fsize);
49 fread(data, fsize, 1, f);
50 fclose(f);
51
52 TF_Buffer* buf = TF_NewBuffer();
53 buf->data = data;
54 buf->length = fsize;
55 buf->data_deallocator = free_buffer;
56 return buf;
57 }
58
59
/*
 * No-op deallocator passed to TF_NewTensor() for tensors that wrap the data
 * of a MATLAB input array (obtained via mxGetData). That memory belongs to
 * MATLAB, so TensorFlow must not free it; consequently such a tensor must be
 * deleted before the MEX call returns. All arguments are intentionally
 * ignored.
 */
void deallocTFData(void* data, size_t len, void* arg) {
	(void) data;
	(void) len;
	(void) arg;
}
63
mexFunction(int nlhs,mxArray * plhs[],int nrhs,const mxArray * prhs[])64 void mexFunction(
65 int nlhs, /* number of expected outputs */
66 mxArray *plhs[], /* array of pointers to output arguments */
67 int nrhs, /* number of inputs */
68 const mxArray *prhs[] /* array of pointers to input arguments */
69 )
70
71 {
72 const mxArray *arg;
73 TF_Buffer* graph_def = NULL;
74 TF_Tensor * tensor = NULL;
75
76 if (nrhs<1) {
77 mexPrintf("mexTF (mexTensorflow) is in a very experimental state.\n");
78 mexPrintf(" Usage of mexTF:\n");
79 mexPrintf("\tv = mexTF()\n\t\treturns tensorflow version\n");
80 mexPrintf("\t[v, graph_def] = mexTF('graph_def')\n\t\treads graph definition file\n");
81 mexPrintf("\t[v, graph_def2] = mexTF(graph_def)\n\t\treads graph definition\n");
82 mexPrintf("\t[v, graph_def2, class] = mexTF(graph_def, data)\n\t\treads graph definition\n");
83 mexPrintf(" Input:\n");
84 mexPrintf(" Output:\nTensorflow version\n");
85 }
86
87 mexPrintf("%s line %d: %d %d\n",__FILE__,__LINE__,nrhs,nlhs);
88 for (int k = 0; k < nrhs; k++) {
89 arg = prhs[k];
90 mxClassID argtype = mxGetClassID(arg);
91
92 if (mxIsEmpty(arg) && (k>0)) {
93 mexPrintf("%s line %d\n",__FILE__,__LINE__);
94 }
95
96 else if ( mxIsChar(arg) && (k==0) ) {
97 mexPrintf("%s line %d\n",__FILE__,__LINE__);
98 char *tmp = mxArrayToString(arg);
99 graph_def = read_file(tmp);
100 mxFree(tmp);
101 }
102
103 else if ( ((argtype==mxINT8_CLASS) || (argtype==mxUINT8_CLASS)) && (k==0) ) {
104 mexPrintf("%s line %d\n",__FILE__,__LINE__);
105 if (!graph_def) {
106 graph_def = TF_NewBuffer();
107 graph_def->data = mxGetData(arg);
108 graph_def->length = mxGetNumberOfElements(arg);
109 graph_def->data_deallocator = NULL;
110 };
111 }
112
113 else if ( (k==1) && mxIsNumeric(arg) ) {
114 mexPrintf("%s line %d\n",__FILE__,__LINE__);
115 TF_DataType tf_type;
116
117 mxClassID typ = mxGetClassID(arg);
118 switch (argtype) {
119 case mxDOUBLE_CLASS:
120 tf_type = TF_DOUBLE;
121 break;
122 case mxSINGLE_CLASS:
123 tf_type = TF_FLOAT;
124 break;
125
126 case mxINT64_CLASS:
127 tf_type = TF_INT64;
128 break;
129 case mxINT32_CLASS:
130 tf_type = TF_INT32;
131 break;
132 case mxINT16_CLASS:
133 tf_type = TF_INT16;
134 break;
135 case mxINT8_CLASS:
136 tf_type = TF_INT8;
137 break;
138
139 case mxUINT64_CLASS:
140 tf_type = TF_UINT64;
141 break;
142 case mxUINT32_CLASS:
143 tf_type = TF_UINT32;
144 break;
145 case mxUINT16_CLASS:
146 tf_type = TF_UINT16;
147 break;
148 case mxUINT8_CLASS:
149 tf_type = TF_UINT8;
150 break;
151
152 default:
153 mexPrintf("Error: data type %s of arg1 not supported\n",mxGetClassName(arg));
154 return;
155 ;
156 }
157
158 int ndims = mxGetNumberOfDimensions(arg);
159 int64_t *dims = calloc(ndims, sizeof(int64_t));
160 for (int k=0; k < ndims; k++) {
161 dims[k] = *(mxGetDimensions(arg) + k);
162 mexPrintf("%s line %d: dim[%d]= %d \n", __FILE__, __LINE__, k, dims[k]);
163 }
164
165 mexPrintf("%s line %d: going to converted to tensor [%d,%d,%d] \n", __FILE__, __LINE__, ndims, mxGetNumberOfElements(arg), TF_DataTypeSize(tf_type));
166
167 tensor = TF_NewTensor( tf_type, dims, ndims, (void*)mxGetData(arg), mxGetNumberOfElements(arg) * TF_DataTypeSize(tf_type), &deallocTFData, NULL);
168
169 mexPrintf("%s line %d: input converted to tensor %p\n", __FILE__, __LINE__, tensor);
170 mexPrintf("%s line %d: input converted to tensor %d %d %d %d \n", __FILE__, __LINE__, TF_NumDims(tensor), TF_TensorByteSize(tensor), TF_Dim(tensor, 0), TF_Dim(tensor, 1));
171
172 free(dims);
173 }
174 }
175
176 mexPrintf("%s line %d\n",__FILE__,__LINE__);
177 plhs[0] = mxCreateString(TF_Version());
178 if ( (nlhs > 1) && graph_def ) {
179 mexPrintf("%s line %d\n",__FILE__,__LINE__);
180 const int ndim = 2;
181 mwSize dims[ndim];
182 dims[0] = 1;
183 dims[1] = graph_def->length;
184 plhs[1] = mxCreateNumericArray(ndim, dims, mxUINT8_CLASS, mxREAL);
185 void *p = mxMalloc(dims[1]);
186 memcpy(p, graph_def->data, dims[1]);
187 mxSetData(plhs[1], p);
188 }
189 mexPrintf("%s line %d\n",__FILE__,__LINE__);
190
191 /***********************************************
192 load graph
193 ***********************************************/
194 // Graph definition from unzipped https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
195 // which is used in the Go, Java and Android examples
196 // TF_Buffer* graph_def = read_file("inception5h/tensorflow_inception_graph.pb");
197 TF_Graph* graph = TF_NewGraph();
198
199 // Import graph_def into graph
200 TF_Status* status = TF_NewStatus();
201 TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
202 TF_GraphImportGraphDef(graph, graph_def, opts, status);
203 TF_DeleteImportGraphDefOptions(opts);
204 TF_DeleteBuffer(graph_def);
205
206 if (TF_GetCode(status) != TF_OK) {
207 fprintf(stderr, "ERROR: Unable to import graph <%s>\n", TF_Message(status));
208 TF_DeleteStatus(status);
209 return;
210 }
211 fprintf(stdout, "Successfully imported graph\n");
212
213
214 if (tensor==NULL) {
215 // Use the graph
216 TF_DeleteGraph(graph);
217 return;
218 }
219
220 /***********************************************
221 run session
222 ***********************************************/
223 TF_SessionOptions * options = TF_NewSessionOptions();
224 TF_Session * session = TF_NewSession( graph, options, status );
225
226 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
227
228 char hello[] = "Hello TensorFlow!";
229 // if (tensor==NULL) tensor = TF_AllocateTensor( TF_STRING, 0, 0, 8 + TF_StringEncodedSize( strlen( hello ) ) );
230
231 TF_Tensor * tensorOutput;
232
233 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
234
235 TF_OperationDescription * operationDescription = TF_NewOperation( graph, "Const", "hello" );
236
237 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
238
239 TF_Operation * operation;
240 struct TF_Output output;
241
242 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
243
244 // TF_StringEncode( hello, strlen( hello ), 8 + ( char * ) TF_TensorData( tensor ), TF_StringEncodedSize( strlen( hello ) ), status );
245 // memset( TF_TensorData( tensor ), 0, 8 );
246
247 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
248
249 TF_SetAttrTensor( operationDescription, "value", tensor, status );
250
251 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
252
253 TF_SetAttrType( operationDescription, "dtype", TF_TensorType( tensor ) );
254
255 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
256
257 operation = TF_FinishOperation( operationDescription, status );
258
259
260 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
261
262 output.oper = operation;
263 output.index = 0;
264
265 TF_SessionRun( session, 0,
266 0, 0, 0, // Inputs
267 &output, &tensorOutput, 1, // Outputs
268 &operation, 1, // Operations
269 0, status );
270
271
272 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
273
274 printf( "status code: %i\n", TF_GetCode( status ) );
275 printf( "%s\n", ( ( char * ) TF_TensorData( tensorOutput ) ) + 9 );
276
277 TF_CloseSession( session, status );
278 TF_DeleteSession( session, status );
279 TF_DeleteStatus( status );
280 TF_DeleteSessionOptions( options );
281 TF_DeleteGraph(graph);
282
283 }
284
285
286