1 /*
2 
3     Copyright (C) 2019 Alois Schloegl <alois.schloegl@ist.ac.at>
4     This file is part of the "BioSig for C/C++" repository
5     (biosig4c++) at http://biosig.sf.net/
6 
7     This program is free software; you can redistribute it and/or
8     modify it under the terms of the GNU General Public License
9     as published by the Free Software Foundation; either version 3
10     of the License, or (at your option) any later version.
11 
12 
13 References:
14     https://stackoverflow.com/questions/44378764/hello-tensorflow-using-the-c-api
15     https://stackoverflow.com/questions/41688217/how-to-load-a-graph-with-tensorflow-so-and-c-api-h-in-c-language
16     https://tebesu.github.io/posts/Training-a-TensorFlow-graph-in-C++-API
17  */
18 
19 
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <tensorflow/c/c_api.h>
24 #include "mex.h"
25 //#include "matrix.h"
26 
27 
28 #ifdef tmwtypes_h
29   #if (MX_API_VER<=0x07020000)
30     typedef int mwSize;
31   #endif
32 #endif
33 
34 
35 TF_Buffer* read_file(const char* file);
36 
/*
 * Deallocator installed on TF_Buffers created by read_file(): releases the
 * malloc()'ed payload. The length argument is required by the callback
 * signature but is not needed by free().
 */
void free_buffer(void* data, size_t length) {
	(void)length;
	free(data);
}
40 
read_file(const char * file)41 TF_Buffer* read_file(const char* file) {
42   FILE *f = fopen(file, "rb");
43   if (f==NULL) return NULL;
44   fseek(f, 0, SEEK_END);
45   long fsize = ftell(f);
46   fseek(f, 0, SEEK_SET);  //same as rewind(f);
47 
48   void* data = malloc(fsize);
49   fread(data, fsize, 1, f);
50   fclose(f);
51 
52   TF_Buffer* buf = TF_NewBuffer();
53   buf->data      = data;
54   buf->length    = fsize;
55   buf->data_deallocator = free_buffer;
56   return buf;
57 }
58 
59 
/*
 * Deallocator passed to TF_NewTensor() in mexFunction. It intentionally does
 * nothing: the tensor data there comes from mxGetData() of an input argument,
 * i.e. memory that MATLAB owns and TensorFlow must not release.
 */
void deallocTFData(void* data, size_t len, void* arg)
{
	(void)data;
	(void)len;
	(void)arg;
}
63 
mexFunction(int nlhs,mxArray * plhs[],int nrhs,const mxArray * prhs[])64 void mexFunction(
65     int           nlhs,           /* number of expected outputs */
66     mxArray       *plhs[],        /* array of pointers to output arguments */
67     int           nrhs,           /* number of inputs */
68     const mxArray *prhs[]         /* array of pointers to input arguments */
69 )
70 
71 {
72 	const mxArray	*arg;
73 	TF_Buffer* graph_def = NULL;
74 	TF_Tensor * tensor   = NULL;
75 
76 	if (nrhs<1) {
77                 mexPrintf("mexTF (mexTensorflow) is in a very experimental state.\n");
78 		mexPrintf("   Usage of mexTF:\n");
79 		mexPrintf("\tv = mexTF()\n\t\treturns tensorflow version\n");
80 		mexPrintf("\t[v, graph_def] = mexTF('graph_def')\n\t\treads graph definition file\n");
81 		mexPrintf("\t[v, graph_def2] = mexTF(graph_def)\n\t\treads graph definition\n");
82 		mexPrintf("\t[v, graph_def2, class] = mexTF(graph_def, data)\n\t\treads graph definition\n");
83 		mexPrintf("   Input:\n");
84 		mexPrintf("   Output:\nTensorflow version\n");
85         }
86 
87         mexPrintf("%s line %d: %d %d\n",__FILE__,__LINE__,nrhs,nlhs);
88 	for (int k = 0; k < nrhs; k++) {
89 		arg = prhs[k];
90 		mxClassID argtype = mxGetClassID(arg);
91 
92 		if (mxIsEmpty(arg) && (k>0)) {
93 			mexPrintf("%s line %d\n",__FILE__,__LINE__);
94 		}
95 
96 		else if ( mxIsChar(arg) && (k==0) )  {
97 			mexPrintf("%s line %d\n",__FILE__,__LINE__);
98 			char *tmp = mxArrayToString(arg);
99 			graph_def = read_file(tmp);
100 			mxFree(tmp);
101 		}
102 
103 		else if ( ((argtype==mxINT8_CLASS) || (argtype==mxUINT8_CLASS)) && (k==0) )  {
104 			mexPrintf("%s line %d\n",__FILE__,__LINE__);
105 			if (!graph_def) {
106 				graph_def = TF_NewBuffer();
107 				graph_def->data   = mxGetData(arg);
108 				graph_def->length = mxGetNumberOfElements(arg);
109 				graph_def->data_deallocator = NULL;
110 			};
111 		}
112 
113 		else if ( (k==1) && mxIsNumeric(arg) )  {
114 			mexPrintf("%s line %d\n",__FILE__,__LINE__);
115 			TF_DataType tf_type;
116 
117 			mxClassID typ = mxGetClassID(arg);
118 			switch (argtype) {
119 			case mxDOUBLE_CLASS:
120 				tf_type = TF_DOUBLE;
121 				break;
122 			case mxSINGLE_CLASS:
123 				tf_type = TF_FLOAT;
124 				break;
125 
126 			case mxINT64_CLASS:
127 				tf_type = TF_INT64;
128 				break;
129 			case mxINT32_CLASS:
130 				tf_type = TF_INT32;
131 				break;
132 			case mxINT16_CLASS:
133 				tf_type = TF_INT16;
134 				break;
135 			case mxINT8_CLASS:
136 				tf_type = TF_INT8;
137 				break;
138 
139 			case mxUINT64_CLASS:
140 				tf_type = TF_UINT64;
141 				break;
142 			case mxUINT32_CLASS:
143 				tf_type = TF_UINT32;
144 				break;
145 			case mxUINT16_CLASS:
146 				tf_type = TF_UINT16;
147 				break;
148 			case mxUINT8_CLASS:
149 				tf_type = TF_UINT8;
150 				break;
151 
152 			default:
153 				mexPrintf("Error: data type %s of arg1 not supported\n",mxGetClassName(arg));
154 				return;
155 				;
156 			}
157 
158 			int ndims = mxGetNumberOfDimensions(arg);
159 			int64_t *dims = calloc(ndims, sizeof(int64_t));
160 			for (int k=0; k < ndims; k++) {
161 				dims[k] = *(mxGetDimensions(arg) + k);
162 				mexPrintf("%s line %d: dim[%d]= %d \n", __FILE__, __LINE__, k, dims[k]);
163 			}
164 
165 			mexPrintf("%s line %d: going to converted to tensor [%d,%d,%d] \n", __FILE__, __LINE__, ndims, mxGetNumberOfElements(arg), TF_DataTypeSize(tf_type));
166 
167 			tensor = TF_NewTensor( tf_type, dims, ndims, (void*)mxGetData(arg), mxGetNumberOfElements(arg) * TF_DataTypeSize(tf_type), &deallocTFData, NULL);
168 
169 			mexPrintf("%s line %d: input converted to tensor %p\n", __FILE__, __LINE__, tensor);
170 			mexPrintf("%s line %d: input converted to tensor %d %d %d %d \n", __FILE__, __LINE__, TF_NumDims(tensor), TF_TensorByteSize(tensor), TF_Dim(tensor, 0), TF_Dim(tensor, 1));
171 
172 			free(dims);
173 		}
174         }
175 
176 	mexPrintf("%s line %d\n",__FILE__,__LINE__);
177         plhs[0] = mxCreateString(TF_Version());
178 	if ( (nlhs > 1) && graph_def ) {
179 		mexPrintf("%s line %d\n",__FILE__,__LINE__);
180 		const int ndim = 2;
181 		mwSize dims[ndim];
182 		dims[0] = 1;
183 		dims[1] = graph_def->length;
184 	        plhs[1] = mxCreateNumericArray(ndim, dims,  mxUINT8_CLASS, mxREAL);
185 		void *p = mxMalloc(dims[1]);
186 		memcpy(p, graph_def->data, dims[1]);
187 		mxSetData(plhs[1], p);
188 	}
189 	mexPrintf("%s line %d\n",__FILE__,__LINE__);
190 
191 	/***********************************************
192 		load graph
193 	 ***********************************************/
194 	// Graph definition from unzipped https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
195 	// which is used in the Go, Java and Android examples
196 	//   TF_Buffer* graph_def = read_file("inception5h/tensorflow_inception_graph.pb");
197 	TF_Graph* graph = TF_NewGraph();
198 
199 	// Import graph_def into graph
200 	TF_Status* status = TF_NewStatus();
201 	TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
202 	TF_GraphImportGraphDef(graph, graph_def, opts, status);
203 	TF_DeleteImportGraphDefOptions(opts);
204 	TF_DeleteBuffer(graph_def);
205 
206 	if (TF_GetCode(status) != TF_OK) {
207 		fprintf(stderr, "ERROR: Unable to import graph <%s>\n", TF_Message(status));
208 		TF_DeleteStatus(status);
209 		return;
210 	}
211 	fprintf(stdout, "Successfully imported graph\n");
212 
213 
214 if (tensor==NULL) {
215 	// Use the graph
216 	TF_DeleteGraph(graph);
217 	return;
218 }
219 
220 	/***********************************************
221 		run session
222 	 ***********************************************/
223 	TF_SessionOptions * options = TF_NewSessionOptions();
224 	TF_Session * session = TF_NewSession( graph, options, status );
225 
226 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
227 
228 	char hello[] = "Hello TensorFlow!";
229 //	if (tensor==NULL) tensor = TF_AllocateTensor( TF_STRING, 0, 0, 8 + TF_StringEncodedSize( strlen( hello ) ) );
230 
231 	TF_Tensor * tensorOutput;
232 
233 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
234 
235 	TF_OperationDescription * operationDescription = TF_NewOperation( graph, "Const", "hello" );
236 
237 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
238 
239 	TF_Operation * operation;
240 	struct TF_Output output;
241 
242 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
243 
244 //	TF_StringEncode( hello, strlen( hello ), 8 + ( char * ) TF_TensorData( tensor ), TF_StringEncodedSize( strlen( hello ) ), status );
245 //	memset( TF_TensorData( tensor ), 0, 8 );
246 
247 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
248 
249 	TF_SetAttrTensor( operationDescription, "value", tensor, status );
250 
251 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
252 
253 	TF_SetAttrType( operationDescription, "dtype", TF_TensorType( tensor ) );
254 
255 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
256 
257 	operation = TF_FinishOperation( operationDescription, status );
258 
259 
260 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
261 
262 	output.oper = operation;
263 	output.index = 0;
264 
265 	TF_SessionRun( session, 0,
266                  0, 0, 0,  // Inputs
267                  &output, &tensorOutput, 1,  // Outputs
268                  &operation, 1,  // Operations
269                  0, status );
270 
271 
272 mexPrintf("%s line %d: %s\n",__FILE__,__LINE__, TF_Message(status));
273 
274 	printf( "status code: %i\n", TF_GetCode( status ) );
275 	printf( "%s\n", ( ( char * ) TF_TensorData( tensorOutput ) ) + 9 );
276 
277 	TF_CloseSession( session, status );
278 	TF_DeleteSession( session, status );
279 	TF_DeleteStatus( status );
280 	TF_DeleteSessionOptions( options );
281 	TF_DeleteGraph(graph);
282 
283 }
284 
285 
286