1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "permute.h"
16
17 namespace ncnn {
18
Permute()19 Permute::Permute()
20 {
21 one_blob_only = true;
22 support_inplace = false;
23 }
24
load_param(const ParamDict & pd)25 int Permute::load_param(const ParamDict& pd)
26 {
27 order_type = pd.get(0, 0);
28
29 return 0;
30 }
31
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const32 int Permute::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
33 {
34 int w = bottom_blob.w;
35 int h = bottom_blob.h;
36 int channels = bottom_blob.c;
37 size_t elemsize = bottom_blob.elemsize;
38
39 int dims = bottom_blob.dims;
40
41 if (dims == 2)
42 {
43 // order_type
44 // 0 = w h
45 // 1 = h w
46
47 if (order_type == 0)
48 {
49 top_blob = bottom_blob;
50 }
51 else if (order_type == 1)
52 {
53 top_blob.create(h, w, elemsize, opt.blob_allocator);
54 if (top_blob.empty())
55 return -100;
56
57 const float* ptr = bottom_blob;
58 float* outptr = top_blob;
59
60 for (int i = 0; i < w; i++)
61 {
62 for (int j = 0; j < h; j++)
63 {
64 outptr[i * h + j] = ptr[j * w + i];
65 }
66 }
67 }
68
69 return 0;
70 }
71
72 // order_type
73 // 0 = w h c
74 // 1 = h w c
75 // 2 = w c h
76 // 3 = c w h
77 // 4 = h c w
78 // 5 = c h w
79
80 if (order_type == 0)
81 {
82 top_blob = bottom_blob;
83 }
84 else if (order_type == 1)
85 {
86 top_blob.create(h, w, channels, elemsize, opt.blob_allocator);
87 if (top_blob.empty())
88 return -100;
89
90 #pragma omp parallel for num_threads(opt.num_threads)
91 for (int q = 0; q < channels; q++)
92 {
93 const float* ptr = bottom_blob.channel(q);
94 float* outptr = top_blob.channel(q);
95
96 for (int i = 0; i < w; i++)
97 {
98 for (int j = 0; j < h; j++)
99 {
100 outptr[i * h + j] = ptr[j * w + i];
101 }
102 }
103 }
104 }
105 else if (order_type == 2)
106 {
107 top_blob.create(w, channels, h, elemsize, opt.blob_allocator);
108 if (top_blob.empty())
109 return -100;
110
111 #pragma omp parallel for num_threads(opt.num_threads)
112 for (int q = 0; q < h; q++)
113 {
114 float* outptr = top_blob.channel(q);
115
116 for (int i = 0; i < channels; i++)
117 {
118 const float* ptr = bottom_blob.channel(i).row(q);
119
120 for (int j = 0; j < w; j++)
121 {
122 outptr[i * w + j] = ptr[j];
123 }
124 }
125 }
126 }
127 else if (order_type == 3)
128 {
129 top_blob.create(channels, w, h, elemsize, opt.blob_allocator);
130 if (top_blob.empty())
131 return -100;
132
133 #pragma omp parallel for num_threads(opt.num_threads)
134 for (int q = 0; q < h; q++)
135 {
136 float* outptr = top_blob.channel(q);
137
138 for (int i = 0; i < w; i++)
139 {
140 for (int j = 0; j < channels; j++)
141 {
142 const float* ptr = bottom_blob.channel(j).row(q);
143
144 outptr[i * channels + j] = ptr[i];
145 }
146 }
147 }
148 }
149 else if (order_type == 4)
150 {
151 top_blob.create(h, channels, w, elemsize, opt.blob_allocator);
152 if (top_blob.empty())
153 return -100;
154
155 #pragma omp parallel for num_threads(opt.num_threads)
156 for (int q = 0; q < w; q++)
157 {
158 float* outptr = top_blob.channel(q);
159
160 for (int i = 0; i < channels; i++)
161 {
162 const float* ptr = bottom_blob.channel(i);
163
164 for (int j = 0; j < h; j++)
165 {
166 outptr[i * h + j] = ptr[j * w + q];
167 }
168 }
169 }
170 }
171 else if (order_type == 5)
172 {
173 top_blob.create(channels, h, w, elemsize, opt.blob_allocator);
174 if (top_blob.empty())
175 return -100;
176
177 #pragma omp parallel for num_threads(opt.num_threads)
178 for (int q = 0; q < w; q++)
179 {
180 float* outptr = top_blob.channel(q);
181
182 for (int i = 0; i < h; i++)
183 {
184 for (int j = 0; j < channels; j++)
185 {
186 const float* ptr = bottom_blob.channel(j);
187
188 outptr[i * channels + j] = ptr[i * w + q];
189 }
190 }
191 }
192 }
193
194 return 0;
195 }
196
197 } // namespace ncnn
198