基于Permutohedral Lattice 的Bilateral filter 源码及部分注释【C++】

举报
ShaderJoy 发表于 2021/12/30 01:01:48 2021/12/30
【摘要】 基于Permutohedral Lattice 的Bilateral filter 源码及部分注释【来自于网络】 实现基于论文《Fast High-Dimensional Filtering Using the Permutohedral Lattice》 . 延伸阅读 saliency filte...


基于Permutohedral Lattice 的Bilateral filter 源码及部分注释【来自于网络】


实现基于论文《Fast High-Dimensional Filtering Using the Permutohedral Lattice》。

延伸阅读 saliency filters精读之permutohedral lattice 

1.bilateralPermutohedral 方法:


  
  1. static Mat bilateralPermutohedral(Mat img, Mat edge, float sigma_s, float sigma_r) // img 和 edge 都必须是CV_32F类型
  2. {
  3. float invSpatialStdev = 1.0f / sigma_s;
  4. float invColorStdev = 1.0f / sigma_r;
  5. // Construct the position vectors out of x, y, r, g, and b.
  6. int height = img.rows;
  7. int width = img.cols;
  8. int eCh = edge.channels(); // 1 或 3
  9. int iCh = img.channels();
  10. Image positions(1, width, height, 2 + eCh); // 只有一个子窗口
  11. Image input(1, width, height, iCh);
  12. //From Mat to Image
  13. for (int y = 0; y < height; y++)
  14. {
  15. float *pimg = img.ptr<float>(y);
  16. float *pedge = edge.ptr<float>(y);
  17. for (int x = 0; x < width; x++)
  18. {
  19. // 参考论文 p4 3.1
  20. // 5维的 positiion vector
  21. positions(x, y)[0] = invSpatialStdev * x; // 0
  22. positions(x, y)[1] = invSpatialStdev * y; // 1
  23. for(int c = 0; c < eCh; c++)
  24. positions(x, y)[2 + c] = invColorStdev * pedge[x * eCh + c]; // 2+
  25. // 3维的 input vector
  26. for(int c = 0; c < iCh; c++)
  27. input(x, y)[c] = pimg[x * iCh + c];
  28. }
  29. }
  30. // Filter the input with respect to the position vectors. (see permutohedral.h)
  31. Image out = PermutohedralLattice::filter(input, positions);
  32. // Save the result
  33. Mat imgOut(img.size(), img.type());
  34. for (int y = 0; y < height; y++)
  35. {
  36. float *pimgOut = imgOut.ptr<float>(y);
  37. for (int x = 0; x < width; x++)
  38. {
  39. for(int c = 0; c < iCh; c++)
  40. pimgOut[x * iCh + c] = out(x, y)[c];
  41. }
  42. }
  43. return imgOut;
  44. }


2. PermutohedralLattice 类:


  
  1. /***************************************************************/
  2. /* The algorithm class that performs the filter
  3. *
  4. * PermutohedralLattice::filter(...) does all the work.
  5. *
  6. */
  7. /***************************************************************/
  8. class PermutohedralLattice
  9. {
  10. public:
  11. /* Filters given image against a reference image.
  12. * im : image to be bilateral-filtered. (input vector)
  13. * ref : reference image whose edges are to be respected. (position vector)
  14. */
  15. static Image filter(Image im, Image ref)
  16. {
  17. //timeval t[5];
  18. // Create lattice
  19. // gettimeofday(t+0, NULL);
  20. // d = ref.channels (5)
  21. // vd = im.channels + 1 (3+1)
  22. PermutohedralLattice lattice(ref.channels, im.channels + 1, im.width * im.height * im.frames);
  23. // Splat into the lattice
  24. // gettimeofday(t+1, NULL);
  25. // printf("Splatting...\n");
  26. float *col = new float[im.channels + 1];
  27. col[im.channels] = 1; // homogeneous coordinate
  28. float *imPtr = im(0, 0, 0);
  29. float *refPtr = ref(0, 0, 0); // position vector
  30. for (int t = 0; t < im.frames; t++)
  31. {
  32. for (int y = 0; y < im.height; y++)
  33. {
  34. for (int x = 0; x < im.width; x++)
  35. {
  36. for (int c = 0; c < im.channels; c++)
  37. {
  38. col[c] = *imPtr++;
  39. }
  40. lattice.splat(refPtr, col);
  41. refPtr += ref.channels;
  42. }
  43. }
  44. }
  45. // Blur the lattice
  46. // gettimeofday(t+2, NULL);
  47. // printf("Blurring...");
  48. lattice.blur();
  49. // Slice from the lattice
  50. // gettimeofday(t+3, NULL);
  51. // printf("Slicing...\n");
  52. Image out(im.frames, im.width, im.height, im.channels);
  53. lattice.beginSlice();
  54. float *outPtr = out(0, 0, 0);
  55. for (int t = 0; t < im.frames; t++)
  56. {
  57. for (int y = 0; y < im.height; y++)
  58. {
  59. for (int x = 0; x < im.width; x++)
  60. {
  61. lattice.slice(col);
  62. float scale = 1.0f / col[im.channels];
  63. for (int c = 0; c < im.channels; c++)
  64. {
  65. *outPtr++ = col[c] * scale;
  66. }
  67. }
  68. }
  69. }
  70. // Print time elapsed for each step
  71. // gettimeofday(t+4, NULL);
  72. // const char *names[4] = {"Init ", "Splat ", "Blur ", "Slice "};
  73. // for (int i = 1; i < 5; i++)
  74. // printf("%s: %3.3f ms\n", names[i-1], (t[i].tv_sec - t[i-1].tv_sec) +
  75. // (t[i].tv_usec - t[i-1].tv_usec)/1000000.0);
  76. return out;
  77. }
  78. /* Constructor
  79. * d_ : dimensionality of key vectors (ref.channels)
  80. * vd_ : dimensionality of value vectors (im.channels + 1)
  81. * nData_ : number of points in the input (im.size * im.frames)
  82. */
  83. PermutohedralLattice(int d_, int vd_, int nData_) :
  84. d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_)
  85. {
  86. // Allocate storage for various arrays
  87. elevated = new float[d + 1];
  88. scaleFactor = new float[d];
  89. greedy = new short[d + 1];
  90. rank = new char[d + 1];
  91. barycentric = new float[d + 2];
  92. replay = new ReplayEntry[nData * (d + 1)];
  93. nReplay = 0;
  94. canonical = new short[(d + 1) * (d + 1)];
  95. key = new short[d + 1];
  96. // compute the coordinates of the canonical simplex, in which
  97. // the difference between a contained point and the zero
  98. // remainder vertex is always in ascending order. (See pg.4 of paper.)
  99. // 论文第四页,d=4的矩阵例子(列主序)
  100. for (int i = 0; i <= d; i++)
  101. {
  102. for (int j = 0; j <= d - i; j++)
  103. canonical[i * (d + 1) + j] = i;
  104. for (int j = d - i + 1; j <= d; j++)
  105. canonical[i * (d + 1) + j] = i - (d + 1);
  106. }
  107. // Compute parts of the rotation matrix E. (See pg.4-5 of paper.)
  108. for (int i = 0; i < d; i++)
  109. {
  110. // the diagonal entries for normalization
  111. scaleFactor[i] = 1.0f / (sqrtf( (float)(i + 1) * (i + 2) ));
  112. /* We presume that the user would like to do a Gaussian blur of standard deviation
  113. * 1 in each dimension (or a total variance of d, summed over dimensions.)
  114. * Because the total variance of the blur performed by this algorithm is not d,
  115. * we must scale the space to offset this.
  116. *
  117. * The total variance of the algorithm is (See pg.6 and 10 of paper):
  118. * [variance of splatting] + [variance of blurring] + [variance of splatting]
  119. * = d(d+1)(d+1)/12 + d(d+1)(d+1)/2 + d(d+1)(d+1)/12
  120. * = 2d(d+1)(d+1)/3.
  121. *
  122. * So we need to scale the space by (d+1)sqrt(2/3).
  123. */
  124. // 论文 第四页 scale position vector
  125. scaleFactor[i] *= (d + 1) * sqrtf(2.0 / 3);
  126. }
  127. }
  128. /* Performs splatting with given position and value vectors */
  129. // position: d-dimension position vector
  130. // value: [r, g, b, 1]
  131. void splat(float *position, float *value)
  132. {
  133. // first rotate position into the (d+1)-dimensional hyperplane
  134. // 论文 第五页 Ex计算
  135. elevated[d] = -d * position[d - 1] * scaleFactor[d - 1];
  136. for (int i = d - 1; i > 0; i--)
  137. elevated[i] = (elevated[i + 1] -
  138. i * position[i - 1] * scaleFactor[i - 1] +
  139. (i + 2) * position[i] * scaleFactor[i]);
  140. elevated[0] = elevated[1] + 2 * position[0] * scaleFactor[0];
  141. // prepare to find the closest lattice points
  142. float scale = 1.0f / (d + 1);
  143. char *myrank = rank;
  144. short *mygreedy = greedy;
  145. // greedily search for the closest zero-colored lattice point
  146. // 论文 第三页
  147. int sum = 0;
  148. for (int i = 0; i <= d; i++)
  149. {
  150. float v = elevated[i] * scale;
  151. float up = ceilf(v) * (d + 1); // 查找最近的整数点,up / down
  152. float down = floorf(v) * (d + 1);
  153. if (up - elevated[i] < elevated[i] - down)
  154. mygreedy[i] = (short)up;
  155. else
  156. mygreedy[i] = (short)down;
  157. sum += mygreedy[i];
  158. }
  159. sum /= d + 1; // consistent remainder (d+1)
  160. // rank differential to find the permutation between this simplex and the canonical one.
  161. // (See pg. 3-4 in paper.)
  162. // 相对差值小的rank++
  163. memset(myrank, 0, sizeof(char) * (d + 1));
  164. for (int i = 0; i < d; i++)
  165. for (int j = i + 1; j <= d; j++)
  166. if (elevated[i] - mygreedy[i] < elevated[j] - mygreedy[j])
  167. myrank[i]++;
  168. else
  169. myrank[j]++;
  170. if (sum > 0)
  171. {
  172. // sum too large - the point is off the hyperplane.
  173. // need to bring down the ones with the smallest differential
  174. for (int i = 0; i <= d; i++)
  175. {
  176. if (myrank[i] >= d + 1 - sum)
  177. {
  178. mygreedy[i] -= d + 1;
  179. myrank[i] += sum - (d + 1);
  180. }
  181. else
  182. myrank[i] += sum;
  183. }
  184. }
  185. else if (sum < 0)
  186. {
  187. // sum too small - the point is off the hyperplane
  188. // need to bring up the ones with largest differential
  189. for (int i = 0; i <= d; i++)
  190. {
  191. if (myrank[i] < -sum)
  192. {
  193. mygreedy[i] += d + 1;
  194. myrank[i] += (d + 1) + sum;
  195. }
  196. else
  197. myrank[i] += sum;
  198. }
  199. }
  200. // Compute barycentric coordinates (See pg.10 of paper.)
  201. memset(barycentric, 0, sizeof(float) * (d + 2));
  202. for (int i = 0; i <= d; i++)
  203. {
  204. barycentric[d - myrank[i]] += (elevated[i] - mygreedy[i]) * scale;
  205. barycentric[d + 1 - myrank[i]] -= (elevated[i] - mygreedy[i]) * scale;
  206. }
  207. barycentric[0] += 1.0f + barycentric[d + 1];
  208. // Splat the value into each vertex of the simplex, with barycentric weights.
  209. for (int remainder = 0; remainder <= d; remainder++)
  210. {
  211. // Compute the location of the lattice point explicitly (all but the last coordinate - it's redundant because they sum to zero)
  212. for (int i = 0; i < d; i++)
  213. key[i] = mygreedy[i] + canonical[remainder * (d + 1) + myrank[i]];
  214. // Retrieve pointer to the value at this vertex.
  215. float *val = hashTable.lookup(key, true);
  216. // Accumulate values with barycentric weight.
  217. for (int i = 0; i < vd; i++)
  218. val[i] += barycentric[remainder] * value[i];
  219. // Record this interaction to use later when slicing
  220. replay[nReplay].offset = val - hashTable.getValues();
  221. replay[nReplay].weight = barycentric[remainder];
  222. nReplay++;
  223. }
  224. }
  225. // Prepare for slicing
  226. void beginSlice()
  227. {
  228. nReplay = 0;
  229. }
  230. /* Performs slicing out of position vectors. Note that the barycentric weights and the simplex
  231. * containing each position vector were calculated and stored in the splatting step.
  232. * We may reuse this to accelerate the algorithm. (See pg. 6 in paper.)
  233. */
  234. void slice(float *col)
  235. {
  236. float *base = hashTable.getValues();
  237. for (int j = 0; j < vd; j++)
  238. col[j] = 0;
  239. for (int i = 0; i <= d; i++)
  240. {
  241. ReplayEntry r = replay[nReplay++];
  242. for (int j = 0; j < vd; j++)
  243. {
  244. col[j] += r.weight * base[r.offset + j];
  245. }
  246. }
  247. }
  248. /* Performs a Gaussian blur along each projected axis in the hyperplane. */
  249. void blur()
  250. {
  251. // Prepare arrays
  252. short *neighbor1 = new short[d + 1];
  253. short *neighbor2 = new short[d + 1];
  254. float *newValue = new float[vd * hashTable.size()];
  255. float *oldValue = hashTable.getValues();
  256. float *hashTableBase = oldValue;
  257. float *zero = new float[vd];
  258. for (int k = 0; k < vd; k++)
  259. zero[k] = 0;
  260. // For each of d+1 axes,
  261. for (int j = 0; j <= d; j++)
  262. {
  263. printf("blur %d\t", j);
  264. fflush(stdout);
  265. // For each vertex in the lattice,
  266. for (int i = 0; i < hashTable.size(); i++) // blur point i in dimension j
  267. {
  268. short *key = hashTable.getKeys() + i * (d); // keys to current vertex
  269. for (int k = 0; k < d; k++)
  270. {
  271. neighbor1[k] = key[k] + 1;
  272. neighbor2[k] = key[k] - 1;
  273. }
  274. neighbor1[j] = key[j] - d;
  275. neighbor2[j] = key[j] + d; // keys to the neighbors along the given axis.
  276. float *oldVal = oldValue + i * vd;
  277. float *newVal = newValue + i * vd;
  278. float *vm1, *vp1;
  279. //printf("first neighbor\n");
  280. vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor
  281. if (vm1)
  282. vm1 = vm1 - hashTableBase + oldValue;
  283. else
  284. vm1 = zero;
  285. //printf("second neighbor\n");
  286. vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor
  287. if (vp1)
  288. vp1 = vp1 - hashTableBase + oldValue;
  289. else
  290. vp1 = zero;
  291. // Mix values of the three vertices
  292. for (int k = 0; k < vd; k++)
  293. newVal[k] = (0.25f * vm1[k] + 0.5f * oldVal[k] + 0.25f * vp1[k]);
  294. }
  295. float *tmp = newValue;
  296. newValue = oldValue;
  297. oldValue = tmp;
  298. // the freshest data is now in oldValue, and newValue is ready to be written over
  299. }
  300. // depending where we ended up, we may have to copy data
  301. if (oldValue != hashTableBase)
  302. {
  303. memcpy(hashTableBase, oldValue, hashTable.size()*vd * sizeof(float));
  304. delete oldValue;
  305. }
  306. else
  307. {
  308. delete newValue;
  309. }
  310. printf("\n");
  311. delete zero;
  312. delete neighbor1;
  313. delete neighbor2;
  314. }
  315. private:
  316. int d, vd, nData;
  317. float *elevated, *scaleFactor, *barycentric;
  318. short *canonical;
  319. short *key;
  320. // slicing is done by replaying splatting (ie storing the sparse matrix)
  321. struct ReplayEntry
  322. {
  323. int offset;
  324. float weight;
  325. } *replay;
  326. int nReplay, nReplaySub;
  327. public:
  328. char *rank;
  329. short *greedy;
  330. HashTablePermutohedral hashTable;
  331. };


3. 用于permutohedral lattice的哈希表:


  
  /***************************************************************/
  /* Hash table implementation for permutohedral lattice
   *
   * The lattice points are stored sparsely using a hash table.
   * The key for each point is its spatial location in the (d+1)-
   * dimensional space.
   *
   * Layout: an open-addressing table of entries (linear probing) whose
   * cells index into two densely packed arrays, keys (kd shorts per
   * point) and values (vd floats per point), both in insertion order.
   *
   * The table owns its arrays and releases them in its destructor, so
   * copying is disabled.
   */
  /***************************************************************/
  class HashTablePermutohedral
  {
  public:
      /* Constructor
       *   kd_: the dimensionality of the position vectors on the hyperplane.
       *   vd_: the dimensionality of the value vectors
       */
      HashTablePermutohedral(int kd_, int vd_) : kd(kd_), vd(vd_)
      {
          capacity = 1 << 15;
          filled = 0;
          entries = new Entry[capacity];
          // The table is never more than half full, so the packed key and
          // value arrays only need capacity / 2 slots.
          keys = new short[kd * capacity / 2];
          values = new float[vd * capacity / 2];
          memset(values, 0, sizeof(float) * vd * capacity / 2);
      }

      /* Releases the entry table and the packed key and value arrays. */
      ~HashTablePermutohedral()
      {
          delete[] entries;
          delete[] keys;
          delete[] values;
      }

      // Returns the number of vectors stored.
      int size()
      {
          return filled;
      }

      // Returns a pointer to the keys array (kd shorts per stored vector).
      short *getKeys()
      {
          return keys;
      }

      // Returns a pointer to the values array (vd floats per stored vector).
      // The pointer is invalidated when the table grows, which can only
      // happen during a creating lookup.
      float *getValues()
      {
          return values;
      }

      /* Returns the offset into the values array for a given key, or -1 if
       * the key is absent and create is false.
       *   key:    a pointer to the position vector.
       *   h:      hash of the position vector, already reduced modulo capacity.
       *   create: a flag specifying whether an entry should be created,
       *           should an entry with the given key not be found.
       */
      int lookupOffset(short *key, size_t h, bool create = true)
      {
          // Double the table size if an insertion could push it past half
          // full. Read-only lookups never grow the table, so pointers that a
          // caller obtained from getKeys()/getValues() stay valid across them.
          if (create && filled >= (capacity / 2) - 1)
          {
              grow();
              // h was reduced modulo the old capacity; recompute it, or the
              // probe would start at the wrong cell and could miss or
              // duplicate the key.
              h = hash(key) % capacity;
          }

          // Find the entry with the given key by linear probing.
          for (;;)
          {
              Entry e = entries[h];

              // An empty cell ends the probe sequence: the key is not stored.
              if (e.keyIdx == -1)
              {
                  if (!create)
                      return -1; // Return not found.
                  // Need to create an entry. Store the given key.
                  for (int i = 0; i < kd; i++)
                      keys[filled * kd + i] = key[i];
                  e.keyIdx = static_cast<int>(filled * kd);
                  e.valueIdx = static_cast<int>(filled * vd);
                  entries[h] = e;
                  filled++;
                  return e.valueIdx;
              }

              // Check if the cell has a matching key.
              bool match = true;
              for (int i = 0; i < kd && match; i++)
                  match = keys[e.keyIdx + i] == key[i];
              if (match)
                  return e.valueIdx;

              // Collision: move to the next cell, wrapping around at the end.
              h++;
              if (h == capacity)
                  h = 0;
          }
      }

      /* Looks up the value vector associated with a given key vector.
       *   k      : pointer to the key vector to be looked up.
       *   create : true if a non-existing key should be created.
       * Returns a pointer to the vd values, or NULL if the key is absent
       * and create is false.
       */
      float *lookup(short *k, bool create = true)
      {
          size_t h = hash(k) % capacity;
          int offset = lookupOffset(k, h, create);
          if (offset < 0)
              return NULL;
          else
              return values + offset;
      }

      /* Hash function used in this implementation. A simple base conversion. */
      size_t hash(const short *key)
      {
          size_t k = 0;
          for (int i = 0; i < kd; i++)
          {
              k += key[i];
              k *= 2531011;
          }
          return k;
      }

  private:
      // The table owns raw arrays, so copying is disabled (declared, never defined).
      HashTablePermutohedral(const HashTablePermutohedral &);
      HashTablePermutohedral &operator=(const HashTablePermutohedral &);

      /* Doubles the capacity of the hash table and rehashes every entry. */
      void grow()
      {
          printf("Resizing hash table\n");
          size_t oldCapacity = capacity;
          capacity *= 2;

          // Migrate the value vectors; unused slots start out as zero.
          float *newValues = new float[vd * capacity / 2];
          memset(newValues, 0, sizeof(float) * vd * capacity / 2);
          memcpy(newValues, values, sizeof(float) * vd * filled);
          delete[] values;
          values = newValues;

          // Migrate the key vectors.
          short *newKeys = new short[kd * capacity / 2];
          memcpy(newKeys, keys, sizeof(short) * kd * filled);
          delete[] keys;
          keys = newKeys;

          // Migrate the table of indices: re-insert every occupied cell at
          // the position given by its key's hash under the new capacity.
          Entry *newEntries = new Entry[capacity];
          for (size_t i = 0; i < oldCapacity; i++)
          {
              if (entries[i].keyIdx == -1)
                  continue;
              size_t h = hash(keys + entries[i].keyIdx) % capacity;
              // If the cell is taken, probe forward until a free one is found.
              while (newEntries[h].keyIdx != -1)
              {
                  h++;
                  if (h == capacity)
                      h = 0;
              }
              newEntries[h] = entries[i];
          }
          delete[] entries;
          entries = newEntries;
      }

      // Private struct for the hash table entries. -1 marks an empty cell.
      struct Entry
      {
          Entry() : keyIdx(-1), valueIdx(-1) {}
          int keyIdx;   // index of the key's first element in keys
          int valueIdx; // index of the value's first element in values
      };

      short *keys;
      float *values;
      Entry *entries;
      size_t capacity, filled; // number of entry cells, and number of cells in use
      int kd, vd;              // elements per key vector and per value vector
  };


效果图:





文章来源: panda1234lee.blog.csdn.net,作者:panda1234lee,版权归原作者所有,如需转载,请联系作者。

原文链接:panda1234lee.blog.csdn.net/article/details/52892359

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0)

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。