关于Tensorflow、Pytorch、Ncnn中NCHW通道的总结

目前深度学习训练和推理涉及到的输入数据通常为4-D,对应的通道格式主要有两种:

  1. NCHW
  2. NHWC

其中各个字母代表的含义为:

  • N - Batch
  • C - Channel 特征图通道
  • H - Height 特征图高度
  • W - Width 特征图宽度

各个框架和图像处理方式对图像数据要求如下:

  • TensorFlow模型默认的输入格式为:RGB NHWC
  • Pytorch模型默认的输入格式为:RGB NCHW
  • ONNX模型默认的输入格式为:RGB NCHW fp32
  • Caffe 的Blob通道顺序是:NCHW
  • TensorRT中通道顺序:NCHW
  • OpenCV默认数据格式为:BGR HWC uint8

NCHW 则是 Nvidia cuDNN 默认格式,使用 GPU 加速时用 NCHW 格式速度会更快

一、基本原理

如图所示,假定N = 2,C = 16,H = 5,W = 4,
无论逻辑表达上是几维的数据,在计算机中存储时都是按照1D来存储的。下面很可以很清楚的看到NCHW和NHWC格式的高位数据,存储为1D时候的样子:
NCHWandNHWC

总的来说,无论是NCHW还是NHWC或者CHWN,在读取为1D时都是从后往前读,举例来说:

  • 对于NCHW格式的4D数据,首先取W方向数据;然后H方向;再C方向;最后N方向。
    所以,序列化出1D数据为:
    000 (W方向) 001 002 003,(H方向) 004 005 … 019,(C方向) 020 … 318 319,(N方向) 320 321 …

  • 对于NHWC格式的4D数据,首先取C方向数据;然后W方向;再H方向;最后N方向。
    所以,序列化出1D数据:
    000 (C方向) 020 … 300,(W方向) 001 021 … 303,(H方向) 004 … 319,(N方向) 320 340 …

我们通常在输入一张256 * 256分辨率的rgb图像时,对应的4D数据为[N = 1, H=256.h, W=256, C=3],然后对应的1D数据的组织方式如下图所示:
Rgb_NCHW_NHWC

NCHW: RRRRRRRRRRGGGGGGGGGGBBBBBBBBBB

NHWC: RBGRGBRGBRGBRGBRGBRGBRGBRGBRGB

二、java调用tensorflow pb模型推理的简单运用

第一种,若 Tensor.create(input) 输入的input是4维数组,那么按照tensorflow要求的NHWC的格式进行数据的组织即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
Imgproc.resize(src, dst, new Size(h, w)); // 1, h, w,3
float input[][][][] = new float[1][h][w][3];
System.out.println(dst.rows());
System.out.println(dst.cols());
for (int i = 0; i < dst.cols(); i++) {
for (int j = 0; j < dst.rows(); j++) {
double[] pixel = dst.get(j, i);
input[0][i][j][0] = (float) (255 - pixel[0]);
input[0][i][j][1] = (float) (255 - pixel[1]);
input[0][i][j][2] = (float) (255 - pixel[2]);
}
}
Tensor input_X = Tensor.create(input);

第二种,若 Tensor.create(input) 输入的input是1维数组,那么按照前面转1D数据的基本原理,将NHWC的格式进行转换后再组织即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Imgproc.resize(src, dst, new Size(h, w)); // 1, h, w,3
float input[] = new float[1 * h * w * 3];
System.out.println(dst.rows());
System.out.println(dst.cols());
int index = 0
for (int i = 0; i < dst.cols(); i++) {
for (int j = 0; j < dst.rows(); j++) {
double[] pixel = dst.get(j, i);
input[index++] = (float) (255 - pixel[0]);
input[index++] = (float) (255 - pixel[1]);
input[index++] = (float) (255 - pixel[2]);
}
}
Tensor input_X = Tensor.create(shape = (1,h,w,3),input);

您的支持将鼓励我继续创作!