# PNNX
PyTorch Neural Network eXchange (PNNX) is an open standard for PyTorch model interoperability. PNNX provides an open model format for PyTorch. It defines a computation graph as well as high-level operators that strictly match PyTorch.

# Rationale
PyTorch is currently one of the most popular machine learning frameworks. We need to deploy trained models to various hardware and environments more conveniently and easily.

Before PNNX, we had the following methods:

1. export to ONNX, and deploy with ONNX Runtime
2. export to ONNX, convert ONNX to an inference-framework specific format, and deploy with TensorRT/OpenVINO/ncnn/etc.
3. export to TorchScript, and deploy with libtorch

As far as we know, ONNX has the ability to express the PyTorch model and it is an open standard. People usually use ONNX as an intermediate representation between PyTorch and the inference platform. However, ONNX still has the following critical problems, which make the birth of PNNX necessary:

1. ONNX does not have a human-readable and editable file representation, making it difficult for users to easily modify the computation graph or add custom operators.
2. The operator definition of ONNX is not completely in accordance with PyTorch. When exporting some PyTorch operators, ONNX often has to insert glue operators, which makes the computation graph inconsistent with PyTorch and may impact the inference efficiency.
3. The ONNX operator definitions carry a large number of additional parameters designed for compatibility with various ML frameworks. These parameters increase the burden of inference implementations on hardware and software.

PNNX tries to define a set of operators and a simple, easy-to-use format that correspond exactly to the Python API of PyTorch, so that the conversion and interoperability of PyTorch models become more convenient.

# Features

1. [Human readable and editable format](#the-pnnxparam-format)
2. [Plain model binary in storage zip](#the-pnnxbin-format)
3. [One-to-one mapping of PNNX operators and PyTorch python api](#pnnx-operator)
4. [Preserve math expression as one operator](#pnnx-expression-operator)
5. [Preserve torch function as one operator](#pnnx-torch-function-operator)
6. [Preserve miscellaneous module as one operator](#pnnx-module-operator)
7. [Inference via exported PyTorch python code](#pnnx-python-inference)
8. [Tensor shape propagation](#pnnx-shape-propagation)
9. [Model optimization](#pnnx-model-optimization)
10. [Custom operator support](#pnnx-custom-operator)

# Build TorchScript to PNNX converter

1. Install the PyTorch and TorchVision C++ libraries
2. Build PNNX with cmake (a minimal build sketch is shown below)

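The following is a minimal build sketch, run from the PNNX source directory (the directory containing this README), assuming PyTorch was installed via pip; paths and options are illustrative and may differ on your system.

```shell
# Point CMake at the libtorch that ships with the pip-installed PyTorch.
# torch.utils.cmake_prefix_path is provided by PyTorch for exactly this purpose.
mkdir -p build && cd build
cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')" ..
cmake --build . -j"$(nproc)"
# If the TorchVision C++ library is required, point CMake at its install prefix as well
# (e.g. by appending it to CMAKE_PREFIX_PATH).
```
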
# Usage

1. Export your model to TorchScript

```python
import torch
import torchvision.models as models

net = models.resnet18(pretrained=True)
net = net.eval()

x = torch.rand(1, 3, 224, 224)

mod = torch.jit.trace(net, x)
torch.jit.save(mod, "resnet18.pt")
```

2. Convert TorchScript to PNNX

```shell
pnnx resnet18.pt inputshape=[1,3,224,224]
```

Normally, you will get six files:

```resnet18.pnnx.param``` PNNX graph definition

```resnet18.pnnx.bin``` PNNX model weight

```resnet18_pnnx.py``` PyTorch script for inference, the Python code for model construction and weight initialization

```resnet18.ncnn.param``` ncnn graph definition

```resnet18.ncnn.bin``` ncnn model weight

```resnet18_ncnn.py``` pyncnn script for inference

3. Visualize PNNX with Netron

Open https://netron.app/ in a browser, and drag resnet18.pnnx.param into it.

4. PNNX command line options

```
Usage: pnnx [model.pt] [(key=value)...]
  pnnxparam=model.pnnx.param
  pnnxbin=model.pnnx.bin
  pnnxpy=model_pnnx.py
  ncnnparam=model.ncnn.param
  ncnnbin=model.ncnn.bin
  optlevel=2
  device=cpu/gpu
  inputshape=[1,3,224,224],...
  inputshape2=[1,3,320,320],...
  customop=/home/nihui/.cache/torch_extensions/fused/fused.so,...
  moduleop=models.common.Focus,models.yolo.Detect,...
Sample usage: pnnx mobilenet_v2.pt inputshape=[1,3,224,224]
              pnnx yolov5s.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] device=gpu moduleop=models.common.Focus,models.yolo.Detect
```

# The pnnx.param format
### example
```
7767517
4 3
pnnx.Input      input       0 1 0
nn.Conv2d       conv_0      1 1 0 1 bias=1 dilation=(1,1) groups=1 in_channels=12 kernel_size=(3,3) out_channels=16 padding=(0,0) stride=(1,1) @bias=(16)f32 @weight=(16,12,3,3)f32
nn.Conv2d       conv_1      1 1 1 2 bias=1 dilation=(1,1) groups=1 in_channels=16 kernel_size=(2,2) out_channels=20 padding=(2,2) stride=(2,2) @bias=(20)f32 @weight=(20,16,2,2)f32
pnnx.Output     output      1 0 2
```
### overview
```
[magic]
```
* magic number : 7767517
```
[operator count] [operand count]
```
* operator count : count of the operator lines that follow
* operand count : count of all operands
### operator line
```
[type] [name] [input count] [output count] [input operands] [output operands] [operator params]
```
* type : type name, such as Conv2d, ReLU, etc.
* name : name of this operator
* input count : count of the operands this operator needs as input
* output count : count of the operands this operator produces as output
* input operands : name list of all the input blob names, separated by space
* output operands : name list of all the output blob names, separated by space
* operator params : key=value pair list, separated by space; operator weights are prefixed by the ```@``` symbol, tensor shapes by the ```#``` symbol, and input parameter keys by the ```$``` symbol (an illustrative line is shown below)
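
For illustration only (this line is not converter output), the conv_0 line from the example above, extended with shape information for its input and output operands, could look like this; here ```@bias```/```@weight``` are operator weights and ```#0```/```#1``` record the shapes of operands 0 and 1:

```
nn.Conv2d       conv_0      1 1 0 1 bias=1 dilation=(1,1) groups=1 in_channels=12 kernel_size=(3,3) out_channels=16 padding=(0,0) stride=(1,1) @bias=(16)f32 @weight=(16,12,3,3)f32 #0=(1,12,64,64)f32 #1=(1,16,62,62)f32
```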

# The pnnx.bin format

The pnnx.bin file is a zip archive in store-only mode (no compression).

Each weight binary is named after its operator name and weight name.

For example, ```nn.Conv2d       conv_0      1 1 0 1 bias=1 dilation=(1,1) groups=1 in_channels=12 kernel_size=(3,3) out_channels=16 padding=(0,0) stride=(1,1) @bias=(16) @weight=(16,12,3,3)``` would place conv_0.weight and conv_0.bias into the pnnx.bin zip archive.

Weight binaries can be listed or modified with any archive application, e.g. 7zip.

![pnnx.bin](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/pnnx.bin.png)
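
Because the archive is store-only, the weights can also be inspected programmatically. A minimal sketch using Python's standard zipfile and numpy; the file name, weight name, dtype and shape below are taken from the example above and are illustrative:

```python
import zipfile
import numpy as np

# List every weight stored in the pnnx.bin archive.
with zipfile.ZipFile("model.pnnx.bin") as archive:
    print(archive.namelist())

    # Read one weight back as a numpy array.
    # dtype and shape must match what the .param file declares, e.g. @weight=(16,12,3,3)f32.
    with archive.open("conv_0.weight") as f:
        w = np.frombuffer(f.read(), dtype=np.float32).reshape(16, 12, 3, 3)
```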

# PNNX operator
PNNX always preserves operators as they are defined by the PyTorch Python API.

Here is the Netron visualization comparison among ONNX, TorchScript and PNNX, with the original PyTorch Python code shown.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.attention = nn.MultiheadAttention(embed_dim=256, num_heads=32)

    def forward(self, x):
        x, _ = self.attention(x, x, x)
        return x
```

|ONNX|TorchScript|PNNX|
|----|---|---|
|![MultiheadAttention.onnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/MultiheadAttention.onnx.png)|![MultiheadAttention.pt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/MultiheadAttention.pt.png)|![MultiheadAttention.pnnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/MultiheadAttention.pnnx.png)|

# PNNX expression operator
PNNX tries to preserve expressions as written in the PyTorch Python code.

Here is the Netron visualization comparison among ONNX, TorchScript and PNNX, with the original PyTorch Python code shown.

```python
import torch

def foo(x, y):
    return torch.sqrt((2 * x + y) / 12)
```

|ONNX|TorchScript|PNNX|
|---|---|---|
|![math.onnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/math.onnx.png)|![math.pt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/math.pt.png)|![math.pnnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/math.pnnx.png)|

# PNNX torch function operator
PNNX tries to preserve torch functions and Tensor member functions as one operator, as they are provided by the PyTorch Python API.

Here is the Netron visualization comparison among ONNX, TorchScript and PNNX, with the original PyTorch Python code shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        x = F.normalize(x, eps=1e-3)
        return x
```

|ONNX|TorchScript|PNNX|
|---|---|---|
|![function.onnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/function.onnx.png)|![function.pt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/function.pt.png)|![function.pnnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/function.pnnx.png)|


# PNNX module operator
Users can ask PNNX to keep a module as one big operator when it has complex logic.

This process is optional and can be enabled via the moduleop command line option.

After pass_level0, all modules are listed in the terminal output, and you can pick the interesting ones as module operators.
```
############# pass_level0
inline module = models.common.Bottleneck
inline module = models.common.C3
inline module = models.common.Concat
inline module = models.common.Conv
inline module = models.common.Focus
inline module = models.common.SPP
inline module = models.yolo.Detect
inline module = utils.activations.SiLU
```

```bash
pnnx yolov5s.pt inputshape=[1,3,640,640] moduleop=models.common.Focus,models.yolo.Detect
```

Here is the Netron visualization comparison among ONNX, TorchScript and PNNX, with the original PyTorch Python code shown.

```python
import torch
import torch.nn as nn

class Focus(nn.Module):
    # Focus wh information into c-space
    # Conv is the standard convolution block defined in models.common of the YOLOv5 codebase
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
```

|ONNX|TorchScript|PNNX|PNNX with module operator|
|---|---|---|---|
|![focus.onnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/focus.onnx.png)|![focus.pt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/focus.pt.png)|![focus.pnnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/focus.pnnx.png)|![focus.pnnx2](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/focus.pnnx2.png)|


# PNNX python inference

A Python script is generated by default when converting TorchScript to PNNX.

This script is the Python code representation of the PNNX model and can be used for model inference.

It includes utility functions for loading the weight binaries from pnnx.bin.

You can even export the model to TorchScript AGAIN from this generated code (a sketch follows the generated code below)!

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.linear_0 = nn.Linear(in_features=128, out_features=256, bias=True)
        self.linear_1 = nn.Linear(in_features=256, out_features=4, bias=True)

    def forward(self, x):
        x = self.linear_0(x)
        x = F.leaky_relu(x, 0.15)
        x = self.linear_1(x)
        return x
```

```python
import os
import numpy as np
import tempfile, zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.linear_0 = nn.Linear(bias=True, in_features=128, out_features=256)
        self.linear_1 = nn.Linear(bias=True, in_features=256, out_features=4)

        archive = zipfile.ZipFile('../../function.pnnx.bin', 'r')
        self.linear_0.bias = self.load_pnnx_bin_as_parameter(archive, 'linear_0.bias', (256), 'float32')
        self.linear_0.weight = self.load_pnnx_bin_as_parameter(archive, 'linear_0.weight', (256,128), 'float32')
        self.linear_1.bias = self.load_pnnx_bin_as_parameter(archive, 'linear_1.bias', (4), 'float32')
        self.linear_1.weight = self.load_pnnx_bin_as_parameter(archive, 'linear_1.weight', (4,256), 'float32')
        archive.close()

    def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype):
        return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype))

    def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):
        _, tmppath = tempfile.mkstemp()
        tmpf = open(tmppath, 'wb')
        with archive.open(key) as keyfile:
            tmpf.write(keyfile.read())
        tmpf.close()
        m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()
        os.remove(tmppath)
        return torch.from_numpy(m)

    def forward(self, v_x_1):
        v_7 = self.linear_0(v_x_1)
        v_input_1 = F.leaky_relu(input=v_7, negative_slope=0.150000)
        v_12 = self.linear_1(v_input_1)
        return v_12
```
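
As a minimal sketch of re-exporting TorchScript from the generated code, assuming the generated script above is saved as model_pnnx.py and the relative path to the pnnx.bin file inside it is valid from the working directory (file and module names here are illustrative):

```python
import torch
from model_pnnx import Model  # the generated script shown above

net = Model()
net.eval()

# Trace with an input matching the model, here (batch, in_features=128).
x = torch.rand(1, 128)
mod = torch.jit.trace(net, x)
torch.jit.save(mod, "model_reexport.pt")
```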

# PNNX shape propagation
Users can ask PNNX to resolve all tensor shapes in the model graph and constant-fold common expressions that become fixed once tensor shapes are known.

This process is optional and can be enabled via the inputshape command line option.

```bash
pnnx shufflenet_v2_x1_0.pt inputshape=[1,3,224,224]
```

```python
import torch
from torch import Tensor

def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x
```

|without shape propagation|with shape propagation|
|---|---|
|![noshapeinfer](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/noshapeinfer.png)|![shapeinfer](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/shapeinfer.pnnx.png)|


# PNNX model optimization

|ONNX|TorchScript|PNNX without optimization|PNNX with optimization|
|---|---|---|---|
|![optlessonnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/optless.onnx.png)|![optlesspt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/optless.pt.png)|![optless](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/optless.pnnx.png)|![opt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/opt.pnnx.png)|


# PNNX custom operator

```python
import os

import torch
from torch.autograd import Function
from torch.utils.cpp_extension import load, _import_module_from_library

module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    'upfirdn2d',
    sources=[
        os.path.join(module_path, 'upfirdn2d.cpp'),
        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
    ],
    is_python_module=False
)

def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    pad_x0 = pad[0]
    pad_x1 = pad[1]
    pad_y0 = pad[0]
    pad_y1 = pad[1]

    kernel_h, kernel_w = kernel.shape
    batch, channel, in_h, in_w = input.shape

    input = input.reshape(-1, in_h, in_w, 1)

    out_h = (in_h * up + pad_y0 + pad_y1 - kernel_h) // down + 1
    out_w = (in_w * up + pad_x0 + pad_x1 - kernel_w) // down + 1

    out = torch.ops.upfirdn2d_op.upfirdn2d(input, kernel, up, up, down, down, pad_x0, pad_x1, pad_y0, pad_y1)

    out = out.view(-1, channel, out_h, out_w)

    return out
```

```cpp
#include <torch/extension.h>

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int64_t up_x, int64_t up_y, int64_t down_x, int64_t down_y,
                        int64_t pad_x0, int64_t pad_x1, int64_t pad_y0, int64_t pad_y1) {
    // operator body
}

TORCH_LIBRARY(upfirdn2d_op, m) {
    m.def("upfirdn2d", upfirdn2d);
}
```
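
When converting a model that uses such an operator, the compiled extension library can be passed to pnnx via the customop option described in the command line options above; the path below is illustrative:

```shell
pnnx model.pt inputshape=[1,3,256,256] customop=/path/to/torch_extensions/upfirdn2d/upfirdn2d.so
```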

<img src="https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/customop.pnnx.png" width="400" />

# Supported PyTorch operator status

| torch.nn        | Is Supported | Export to ncnn |
|---------------------------|----|---|
|nn.AdaptiveAvgPool1d       | :heavy_check_mark: | :heavy_check_mark: |
|nn.AdaptiveAvgPool2d       | :heavy_check_mark: | :heavy_check_mark: |
|nn.AdaptiveAvgPool3d       | :heavy_check_mark: | :heavy_check_mark: |
|nn.AdaptiveMaxPool1d       | :heavy_check_mark: | :heavy_check_mark: |
|nn.AdaptiveMaxPool2d       | :heavy_check_mark: | :heavy_check_mark: |
|nn.AdaptiveMaxPool3d       | :heavy_check_mark: | :heavy_check_mark: |
|nn.AlphaDropout            |   |
|nn.AvgPool1d               | :heavy_check_mark: | :heavy_check_mark:* |
|nn.AvgPool2d               | :heavy_check_mark: | :heavy_check_mark:* |
|nn.AvgPool3d               | :heavy_check_mark: | :heavy_check_mark:* |
|nn.BatchNorm1d             | :heavy_check_mark: | :heavy_check_mark: |
|nn.BatchNorm2d             | :heavy_check_mark: | :heavy_check_mark: |
|nn.BatchNorm3d             | :heavy_check_mark: | :heavy_check_mark: |
|nn.Bilinear                |   |
|nn.CELU                    | :heavy_check_mark: |
|nn.ChannelShuffle          | :heavy_check_mark: | :heavy_check_mark: |
|nn.ConstantPad1d           | :heavy_check_mark: | :heavy_check_mark: |
|nn.ConstantPad2d           | :heavy_check_mark: | :heavy_check_mark: |
|nn.ConstantPad3d           | :heavy_check_mark: | :heavy_check_mark: |
|nn.Conv1d                  | :heavy_check_mark: | :heavy_check_mark: |
|nn.Conv2d                  | :heavy_check_mark: | :heavy_check_mark: |
|nn.Conv3d                  | :heavy_check_mark: | :heavy_check_mark: |
|nn.ConvTranspose1d         | :heavy_check_mark: |
|nn.ConvTranspose2d         | :heavy_check_mark: | :heavy_check_mark: |
|nn.ConvTranspose3d         | :heavy_check_mark: |
|nn.CosineSimilarity        |   |
|nn.Dropout                 |   | :heavy_check_mark:* |
|nn.Dropout2d               |   |
|nn.Dropout3d               |   |
|nn.ELU                     | :heavy_check_mark: | :heavy_check_mark: |
|nn.Embedding               | :heavy_check_mark: | :heavy_check_mark: |
|nn.EmbeddingBag            |   |
|nn.Flatten                 | :heavy_check_mark: |
|nn.Fold                    |   |
|nn.FractionalMaxPool2d     |   |
|nn.FractionalMaxPool3d     |   |
|nn.GELU                    | :heavy_check_mark: | :heavy_check_mark: |
|nn.GroupNorm               | :heavy_check_mark: | :heavy_check_mark: |
|nn.GRU                     | :heavy_check_mark: | :heavy_check_mark: |
|nn.GRUCell                 |   |
|nn.Hardshrink              | :heavy_check_mark: |
|nn.Hardsigmoid             | :heavy_check_mark: | :heavy_check_mark: |
|nn.Hardswish               | :heavy_check_mark: | :heavy_check_mark: |
|nn.Hardtanh                | :heavy_check_mark: | :heavy_check_mark: |
|nn.Identity                |   |
|nn.InstanceNorm1d          | :heavy_check_mark: |
|nn.InstanceNorm2d          | :heavy_check_mark: | :heavy_check_mark: |
|nn.InstanceNorm3d          | :heavy_check_mark: |
|nn.LayerNorm               | :heavy_check_mark: | :heavy_check_mark: |
|nn.LazyBatchNorm1d         |   |
|nn.LazyBatchNorm2d         |   |
|nn.LazyBatchNorm3d         |   |
|nn.LazyConv1d              |   |
|nn.LazyConv2d              |   |
|nn.LazyConv3d              |   |
|nn.LazyConvTranspose1d     |   |
|nn.LazyConvTranspose2d     |   |
|nn.LazyConvTranspose3d     |   |
|nn.LazyLinear              |   |
|nn.LeakyReLU               | :heavy_check_mark: | :heavy_check_mark: |
|nn.Linear                  | :heavy_check_mark: | :heavy_check_mark: |
|nn.LocalResponseNorm       | :heavy_check_mark: | :heavy_check_mark: |
|nn.LogSigmoid              | :heavy_check_mark: |
|nn.LogSoftmax              | :heavy_check_mark: |
|nn.LPPool1d                | :heavy_check_mark: |
|nn.LPPool2d                | :heavy_check_mark: |
|nn.LSTM                    | :heavy_check_mark: | :heavy_check_mark: |
|nn.LSTMCell                |   |
|nn.MaxPool1d               | :heavy_check_mark: | :heavy_check_mark: |
|nn.MaxPool2d               | :heavy_check_mark: | :heavy_check_mark: |
|nn.MaxPool3d               | :heavy_check_mark: | :heavy_check_mark: |
|nn.MaxUnpool1d             |   |
|nn.MaxUnpool2d             |   |
|nn.MaxUnpool3d             |   |
|nn.Mish                    | :heavy_check_mark: | :heavy_check_mark: |
|nn.MultiheadAttention      | :heavy_check_mark: | :heavy_check_mark:* |
|nn.PairwiseDistance        |   |
|nn.PixelShuffle            | :heavy_check_mark: | :heavy_check_mark: |
|nn.PixelUnshuffle          | :heavy_check_mark: | :heavy_check_mark: |
|nn.PReLU                   | :heavy_check_mark: | :heavy_check_mark: |
|nn.ReflectionPad1d         | :heavy_check_mark: | :heavy_check_mark: |
|nn.ReflectionPad2d         | :heavy_check_mark: | :heavy_check_mark: |
|nn.ReLU                    | :heavy_check_mark: | :heavy_check_mark: |
|nn.ReLU6                   | :heavy_check_mark: | :heavy_check_mark: |
|nn.ReplicationPad1d        | :heavy_check_mark: | :heavy_check_mark: |
|nn.ReplicationPad2d        | :heavy_check_mark: | :heavy_check_mark: |
|nn.ReplicationPad3d        | :heavy_check_mark: |
|nn.RNN                     | :heavy_check_mark: | :heavy_check_mark:* |
|nn.RNNBase                 |   |
|nn.RNNCell                 |   |
|nn.RReLU                   | :heavy_check_mark: |
|nn.SELU                    | :heavy_check_mark: | :heavy_check_mark: |
|nn.Sigmoid                 | :heavy_check_mark: | :heavy_check_mark: |
|nn.SiLU                    | :heavy_check_mark: | :heavy_check_mark: |
|nn.Softmax                 | :heavy_check_mark: | :heavy_check_mark: |
|nn.Softmax2d               |   |
|nn.Softmin                 | :heavy_check_mark: |
|nn.Softplus                | :heavy_check_mark: |
|nn.Softshrink              | :heavy_check_mark: |
|nn.Softsign                | :heavy_check_mark: |
|nn.SyncBatchNorm           |   |
|nn.Tanh                    | :heavy_check_mark: | :heavy_check_mark: |
|nn.Tanhshrink              | :heavy_check_mark: |
|nn.Threshold               | :heavy_check_mark: |
|nn.Transformer             |   |
|nn.TransformerDecoder      |   |
|nn.TransformerDecoderLayer |   |
|nn.TransformerEncoder      |   |
|nn.TransformerEncoderLayer |   |
|nn.Unflatten               |   |
|nn.Unfold                  |   |
|nn.Upsample                | :heavy_check_mark: | :heavy_check_mark: |
|nn.UpsamplingBilinear2d    | :heavy_check_mark: | :heavy_check_mark: |
|nn.UpsamplingNearest2d     | :heavy_check_mark: | :heavy_check_mark: |
|nn.ZeroPad2d               | :heavy_check_mark: | :heavy_check_mark: |


| torch.nn.functional | Is Supported | Export to ncnn |
|---------------------------|----|----|
|F.adaptive_avg_pool1d      | :heavy_check_mark: | :heavy_check_mark: |
|F.adaptive_avg_pool2d      | :heavy_check_mark: | :heavy_check_mark: |
|F.adaptive_avg_pool3d      | :heavy_check_mark: | :heavy_check_mark: |
|F.adaptive_max_pool1d      | :heavy_check_mark: | :heavy_check_mark: |
|F.adaptive_max_pool2d      | :heavy_check_mark: | :heavy_check_mark: |
|F.adaptive_max_pool3d      | :heavy_check_mark: | :heavy_check_mark: |
|F.affine_grid              | :heavy_check_mark: | :heavy_check_mark: |
|F.alpha_dropout            |  |
|F.avg_pool1d               | :heavy_check_mark: | :heavy_check_mark:* |
|F.avg_pool2d               | :heavy_check_mark: | :heavy_check_mark:* |
|F.avg_pool3d               | :heavy_check_mark: | :heavy_check_mark:* |
|F.batch_norm               | :heavy_check_mark: | :heavy_check_mark: |
|F.bilinear                 |  |
|F.celu                     | :heavy_check_mark: |
|F.conv1d                   | :heavy_check_mark: |
|F.conv2d                   | :heavy_check_mark: | :heavy_check_mark:* |
|F.conv3d                   | :heavy_check_mark: |
|F.conv_transpose1d         | :heavy_check_mark: |
|F.conv_transpose2d         | :heavy_check_mark: |
|F.conv_transpose3d         | :heavy_check_mark: |
|F.cosine_similarity        |  |
|F.dropout                  |  |
|F.dropout2d                |  |
|F.dropout3d                |  |
|F.elu                      | :heavy_check_mark: | :heavy_check_mark: |
|F.elu_                     | :heavy_check_mark: | :heavy_check_mark: |
|F.embedding                |  |
|F.embedding_bag            |  |
|F.feature_alpha_dropout    |  |
|F.fold                     |  |
|F.fractional_max_pool2d    |  |
|F.fractional_max_pool3d    |  |
|F.gelu                     | :heavy_check_mark: | :heavy_check_mark: |
|F.glu                      |  |
|F.grid_sample              | :heavy_check_mark: |
|F.group_norm               | :heavy_check_mark: | :heavy_check_mark: |
|F.gumbel_softmax           |  |
|F.hardshrink               | :heavy_check_mark: |
|F.hardsigmoid              | :heavy_check_mark: | :heavy_check_mark: |
|F.hardswish                | :heavy_check_mark: | :heavy_check_mark: |
|F.hardtanh                 | :heavy_check_mark: | :heavy_check_mark: |
|F.hardtanh_                | :heavy_check_mark: | :heavy_check_mark: |
|F.instance_norm            | :heavy_check_mark: | :heavy_check_mark: |
|F.interpolate              | :heavy_check_mark: | :heavy_check_mark: |
|F.layer_norm               | :heavy_check_mark: | :heavy_check_mark: |
|F.leaky_relu               | :heavy_check_mark: | :heavy_check_mark: |
|F.leaky_relu_              | :heavy_check_mark: | :heavy_check_mark: |
|F.linear                   | :heavy_check_mark: | :heavy_check_mark:* |
|F.local_response_norm      | :heavy_check_mark: | :heavy_check_mark: |
|F.logsigmoid               | :heavy_check_mark: |
|F.log_softmax              | :heavy_check_mark: |
|F.lp_pool1d                | :heavy_check_mark: |
|F.lp_pool2d                | :heavy_check_mark: |
|F.max_pool1d               | :heavy_check_mark: | :heavy_check_mark: |
|F.max_pool2d               | :heavy_check_mark: | :heavy_check_mark: |
|F.max_pool3d               | :heavy_check_mark: | :heavy_check_mark: |
|F.max_unpool1d             |  |
|F.max_unpool2d             |  |
|F.max_unpool3d             |  |
|F.mish                     | :heavy_check_mark: | :heavy_check_mark: |
|F.normalize                | :heavy_check_mark: | :heavy_check_mark: |
|F.one_hot                  |  |
|F.pad                      | :heavy_check_mark: | :heavy_check_mark: |
|F.pairwise_distance        |  |
|F.pdist                    |  |
|F.pixel_shuffle            | :heavy_check_mark: | :heavy_check_mark: |
|F.pixel_unshuffle          | :heavy_check_mark: | :heavy_check_mark: |
|F.prelu                    | :heavy_check_mark: | :heavy_check_mark: |
|F.relu                     | :heavy_check_mark: | :heavy_check_mark: |
|F.relu_                    | :heavy_check_mark: | :heavy_check_mark: |
|F.relu6                    | :heavy_check_mark: | :heavy_check_mark: |
|F.rrelu                    | :heavy_check_mark: |
|F.rrelu_                   | :heavy_check_mark: |
|F.selu                     | :heavy_check_mark: | :heavy_check_mark: |
|F.sigmoid                  | :heavy_check_mark: | :heavy_check_mark: |
|F.silu                     | :heavy_check_mark: | :heavy_check_mark: |
|F.softmax                  | :heavy_check_mark: | :heavy_check_mark: |
|F.softmin                  | :heavy_check_mark: |
|F.softplus                 | :heavy_check_mark: |
|F.softshrink               | :heavy_check_mark: |
|F.softsign                 | :heavy_check_mark: |
|F.tanh                     | :heavy_check_mark: | :heavy_check_mark: |
|F.tanhshrink               | :heavy_check_mark: |
|F.threshold                | :heavy_check_mark: |
|F.threshold_               | :heavy_check_mark: |
|F.unfold                   |  |
|F.upsample                 | :heavy_check_mark: | :heavy_check_mark: |
|F.upsample_bilinear        | :heavy_check_mark: | :heavy_check_mark: |
|F.upsample_nearest         | :heavy_check_mark: | :heavy_check_mark: |
