11 #ifndef vtk_m_worklet_waveletcompressor_h
12 #define vtk_m_worklet_waveletcompressor_h
14 #include <vtkm/worklet/wavelets/WaveletDWT.h>
// NOTE(review): excerpt of a 1D multi-level decomposition routine; the embedded
// line numbers skip, so the signature, braces and several statements (e.g. the
// declarations of len, tlen, cptr, CLength, sigInPtr) are not visible here.
34 template <
typename SignalArrayType,
typename CoeffArrayType>
37 CoeffArrayType& coeffOut,
38 std::vector<vtkm::Id>& L)
// Reject level counts outside [0, GetWaveletMaxLevel(sigInLen)]
// (the body of this check is not visible).
40 vtkm::Id sigInLen = sigIn.GetNumberOfValues();
41 if (nLevels < 0 || nLevels > WaveletBase::GetWaveletMaxLevel(sigInLen))
// Fill the bookkeeping vector L with per-level lengths for this signal.
51 this->
ComputeL(sigInLen, nLevels, L);
57 vtkm::Id cALen = WaveletBase::GetApproxLength(len);
// Three-entry length descriptor passed to each single-level DWT1D call.
60 std::vector<vtkm::Id> L1d(3, 0);
63 using OutputValueType =
typename CoeffArrayType::ValueType;
// Iterate i from nLevels down to 1, one forward transform per iteration.
72 for (
vtkm::Id i = nLevels; i > 0; i--)
// Offset in coeffOut at which this iteration's output is written.
75 cptr = 0 + CLength - tlen - cALen;
// Access `len` values of coeffOut through a permutation view
// (presumably a counting index array: start sigInPtr, step 1 -- confirm).
78 IdArrayType inputIndices(sigInPtr, 1, len);
79 PermutArrayType input(inputIndices, coeffOut);
81 InterArrayType output;
// Single-level 1D forward transform of `input` into `output`.
83 WaveletDWT::DWT1D(input, output, L1d);
// Copy the level's coefficients into coeffOut starting at cptr.
86 WaveletBase::DeviceCopyStartX(output, coeffOut, cptr);
// Shrink cALen to the approximation length for the next iteration.
90 cALen = WaveletBase::GetApproxLength(cALen);
// NOTE(review): excerpt of a 1D multi-level reconstruction routine; the
// signature, braces and part of the body (e.g. the initial fill of L1d from L)
// are not visible here.
98 template <
typename CoeffArrayType,
typename SignalArrayType>
101 std::vector<vtkm::Id>& L,
102 SignalArrayType& sigOut)
// Three-entry length descriptor passed to each single-level IDWT1D call.
108 std::vector<vtkm::Id> L1d(3, 0);
112 using OutValueType =
typename SignalArrayType::ValueType;
// Iterate i from 1 up to nLevels, one inverse transform per iteration.
119 for (
vtkm::Id i = 1; i <= nLevels; i++)
// View the first L1d[2] values of sigOut through a permutation
// (presumably a counting index array: start 0, step 1 -- confirm).
124 IdArrayType inputIndices(0, 1, L1d[2]);
125 PermutArrayType input(inputIndices, sigOut);
128 OutArrayBasic output;
// Single-level 1D inverse transform described by L1d.
130 WaveletDWT::IDWT1D(input, L1d, output);
// Copy the result back into sigOut starting at offset 0.
134 WaveletBase::DeviceCopyStartX(output, sigOut, 0);
// Refresh L1d[1] from the bookkeeping vector for the next iteration.
137 L1d[1] = L[size_t(i + 1)];
// NOTE(review): excerpt of a 3D multi-level decomposition routine; the
// signature, braces and several statements are not visible here.
144 template <
typename InArrayType,
typename OutArrayType>
150 OutArrayType& coeffOut,
153 vtkm::Id sigInLen = sigIn.GetNumberOfValues();
// nLevels must lie in [0, max level] for each of the three extents.
155 if (nLevels < 0 || nLevels > WaveletBase::GetWaveletMaxLevel(inX) ||
156 nLevels > WaveletBase::GetWaveletMaxLevel(inY) ||
157 nLevels > WaveletBase::GetWaveletMaxLevel(inZ))
171 using OutValueType =
typename OutArrayType::ValueType;
// Argument list of the first-level transform from sigIn into coeffOut
// (the head of this call is not visible).
176 sigIn, inX, inY, inZ, 0, 0, 0, currentLenX, currentLenY, currentLenZ, coeffOut, discardSigIn);
// The remaining nLevels - 1 levels operate on coeffOut.
179 for (
vtkm::Id i = nLevels - 1; i > 0; i--)
// Shrink the working region to the approximation extents.
181 currentLenX = WaveletBase::GetApproxLength(currentLenX);
182 currentLenY = WaveletBase::GetApproxLength(currentLenY);
183 currentLenZ = WaveletBase::GetApproxLength(currentLenZ);
185 OutBasicArray tempOutput;
// Transform the current region of coeffOut into a temporary buffer...
187 computationTime += WaveletDWT::DWT3D(
188 coeffOut, inX, inY, inZ, 0, 0, 0, currentLenX, currentLenY, currentLenZ, tempOutput,
false);
// ...then copy it back into coeffOut (the 0, 0, 0 arguments are presumably
// the destination offsets -- confirm).
191 WaveletBase::DeviceCubeCopyTo(
192 tempOutput, currentLenX, currentLenY, currentLenZ, coeffOut, inX, inY, inZ, 0, 0, 0);
// Return the time accumulated from the DWT3D calls.
195 return computationTime;
// NOTE(review): excerpt of a 3D multi-level reconstruction routine; the
// signature, braces, the setup of outBuffer and the body of the first copy
// loop are not visible here.
199 template <
typename InArrayType,
typename OutArrayType>
206 OutArrayType& arrOut,
209 vtkm::Id arrInLen = arrIn.GetNumberOfValues();
// nLevels must lie in [0, max level] for each of the three extents.
211 if (nLevels < 0 || nLevels > WaveletBase::GetWaveletMaxLevel(inX) ||
212 nLevels > WaveletBase::GetWaveletMaxLevel(inY) ||
213 nLevels > WaveletBase::GetWaveletMaxLevel(inZ))
217 using OutValueType =
typename OutArrayType::ValueType;
221 OutBasicArray outBuffer;
227 else if (discardArrIn)
// Bookkeeping vector for all levels, plus a 27-entry per-level descriptor:
// entries 0-23 are taken from L and entries 24-26 hold the sums computed below.
236 std::vector<vtkm::Id> L;
237 this->
ComputeL3(inX, inY, inZ, nLevels, L);
238 std::vector<vtkm::Id> L3d(27, 0);
// Seeds L3d[0..23] (the copy statement itself is not visible).
241 for (
size_t i = 0; i < 24; i++)
// Runs nLevels - 1 times; the final IDWT3D call further down writes arrOut.
245 for (
size_t i = 1; i < static_cast<size_t>(nLevels); i++)
247 L3d[24] = L3d[0] + L3d[12];
248 L3d[25] = L3d[1] + L3d[7];
249 L3d[26] = L3d[2] + L3d[5];
251 OutBasicArray tempOutput;
// Inverse transform of outBuffer into a temporary...
255 WaveletDWT::IDWT3D(outBuffer, inX, inY, inZ, 0, 0, 0, L3d, tempOutput,
false);
// ...copied back into outBuffer as an L3d[24] x L3d[25] x L3d[26] cube.
258 WaveletBase::DeviceCubeCopyTo(
259 tempOutput, L3d[24], L3d[25], L3d[26], outBuffer, inX, inY, inZ, 0, 0, 0);
// Load entries 3-23 for the next level from L at offset 21 * i.
265 for (
size_t j = 3; j < 24; j++)
267 L3d[j] = L[21 * i + j];
// Last level: recompute the sums and reconstruct directly into arrOut.
272 L3d[24] = L3d[0] + L3d[12];
273 L3d[25] = L3d[1] + L3d[7];
274 L3d[26] = L3d[2] + L3d[5];
275 computationTime += WaveletDWT::IDWT3D(outBuffer, inX, inY, inZ, 0, 0, 0, L3d, arrOut,
true);
277 return computationTime;
// NOTE(review): excerpt of a 2D multi-level decomposition routine; the
// signature, braces and several statements are not visible here.
281 template <
typename InArrayType,
typename OutArrayType>
286 OutArrayType& coeffOut,
287 std::vector<vtkm::Id>& L)
289 vtkm::Id sigInLen = sigIn.GetNumberOfValues();
// nLevels must lie in [0, max level] for both extents.
291 if (nLevels < 0 || nLevels > WaveletBase::GetWaveletMaxLevel(inX) ||
292 nLevels > WaveletBase::GetWaveletMaxLevel(inY))
// Ten-entry length descriptor passed to each DWT2D call.
308 std::vector<vtkm::Id> L2d(10, 0);
311 using OutValueType =
typename OutArrayType::ValueType;
// First level: sigIn -> coeffOut; coeffOut must then hold exactly
// currentLenX * currentLenY values.
315 computationTime += WaveletDWT::DWT2D(
316 sigIn, currentLenX, currentLenY, 0, 0, currentLenX, currentLenY, coeffOut, L2d);
317 VTKM_ASSERT(coeffOut.GetNumberOfValues() == currentLenX * currentLenY);
// Shrink the working region to the approximation extents.
318 currentLenX = WaveletBase::GetApproxLength(currentLenX);
319 currentLenY = WaveletBase::GetApproxLength(currentLenY);
// Remaining nLevels - 1 levels: transform the current region of coeffOut
// into tempOutput, copy it back into coeffOut, then shrink the region again.
322 for (
vtkm::Id i = nLevels - 1; i > 0; i--)
324 OutBasicArray tempOutput;
327 WaveletDWT::DWT2D(coeffOut, inX, inY, 0, 0, currentLenX, currentLenY, tempOutput, L2d);
330 WaveletBase::DeviceRectangleCopyTo(
331 tempOutput, currentLenX, currentLenY, coeffOut, inX, inY, 0, 0);
334 currentLenX = WaveletBase::GetApproxLength(currentLenX);
335 currentLenY = WaveletBase::GetApproxLength(currentLenY);
338 return computationTime;
// NOTE(review): excerpt of a 2D multi-level reconstruction routine; the
// signature, the setup of outBuffer and the initial fill of L2d from L are
// not visible here.
342 template <
typename InArrayType,
typename OutArrayType>
347 OutArrayType& arrOut,
348 std::vector<vtkm::Id>& L)
350 vtkm::Id arrInLen = arrIn.GetNumberOfValues();
// nLevels must lie in [0, max level] for both extents.
352 if (nLevels < 0 || nLevels > WaveletBase::GetWaveletMaxLevel(inX) ||
353 nLevels > WaveletBase::GetWaveletMaxLevel(inY))
357 using OutValueType =
typename OutArrayType::ValueType;
361 OutBasicArray outBuffer;
// Ten-entry length descriptor; entries 8 and 9 hold the sums computed below.
374 std::vector<vtkm::Id> L2d(10, 0);
// Runs nLevels - 1 times; the final IDWT2D call further down writes arrOut.
385 for (
size_t i = 1; i < static_cast<size_t>(nLevels); i++)
387 L2d[8] = L2d[0] + L2d[4];
388 L2d[9] = L2d[1] + L2d[3];
390 OutBasicArray tempOutput;
393 computationTime += WaveletDWT::IDWT2D(outBuffer, inX, inY, 0, 0, L2d, tempOutput);
// Copy the L2d[8] x L2d[9] result back into outBuffer.
396 WaveletBase::DeviceRectangleCopyTo(tempOutput, L2d[8], L2d[9], outBuffer, inX, inY, 0, 0);
// Load entries 2-7 for the next level from L at offset 6 * i.
401 L2d[2] = L[6 * i + 2];
402 L2d[3] = L[6 * i + 3];
403 L2d[4] = L[6 * i + 4];
404 L2d[5] = L[6 * i + 5];
405 L2d[6] = L[6 * i + 6];
406 L2d[7] = L[6 * i + 7];
// Last level: recompute the sums and reconstruct directly into arrOut.
410 L2d[8] = L2d[0] + L2d[4];
411 L2d[9] = L2d[1] + L2d[3];
412 computationTime += WaveletDWT::IDWT2D(outBuffer, inX, inY, 0, 0, L2d, arrOut);
414 return computationTime;
// NOTE(review): excerpt of a coefficient-thresholding routine; the signature,
// the fill of sortedArray and the computation of nthVal are not visible here.
418 template <
typename CoeffArrayType>
423 vtkm::Id coeffLen = coeffIn.GetNumberOfValues();
424 using ValueType =
typename CoeffArrayType::ValueType;
// A separate array is sorted (presumably a copy of coeffIn, so coeffIn keeps
// its order -- the copy is not visible; confirm).
426 CoeffArrayBasic sortedArray;
429 WaveletBase::DeviceSort(sortedArray);
// Run ThresholdWorklet with cutoff nthVal over coeffIn
// (presumably modifying the coefficients in place -- confirm).
438 using ThresholdType = vtkm::worklet::wavelets::ThresholdWorklet;
439 ThresholdType thresholdWorklet(nthVal);
441 dispatcher.Invoke(coeffIn);
// NOTE(review): excerpt of a reconstruction-quality report comparing
// `original` against `reconstruct`; the signature, the declarations of
// snr/decibels/rmse and the branch conditions are not visible here.
448 template <
typename ArrayType>
// NOTE(review): these macros sit inside a function body but are
// preprocessor-global; no #undef is visible in this excerpt -- confirm they
// are undefined afterwards so they do not leak into later code.
451 #define VAL vtkm::Float64
452 #define MAKEVAL(a) (static_cast<VAL>(a))
// Variance of the original data; numerator of the SNR below.
453 VAL VarOrig = WaveletBase::DeviceCalculateVariance(original);
455 using ValueType =
typename ArrayType::ValueType;
457 ArrayBasic errorArray, errorSquare;
// errorArray <- Differencer(original, reconstruct)
// (presumably the element-wise difference -- confirm against the worklet).
460 using DifferencerWorklet = vtkm::worklet::wavelets::Differencer;
461 DifferencerWorklet dw;
463 dwDispatcher.Invoke(original, reconstruct, errorArray);
// errorSquare <- SquareWorklet(errorArray) (presumably element-wise square).
465 using SquareWorklet = vtkm::worklet::wavelets::SquareWorklet;
468 swDispatcher.Invoke(errorArray, errorSquare);
470 VAL varErr = WaveletBase::DeviceCalculateVariance(errorArray);
// SNR is the variance ratio. The alternate branch reports infinity
// (presumably when varErr is zero -- the condition is not visible).
474 snr = VarOrig / varErr;
479 snr = vtkm::Infinity64();
480 decibels = vtkm::Infinity64();
// Statistics used in the printed report.
483 VAL origMax = WaveletBase::DeviceMax(original);
484 VAL origMin = WaveletBase::DeviceMin(original);
485 VAL errorMax = WaveletBase::DeviceMaxAbs(errorArray);
486 VAL range = origMax - origMin;
488 VAL squareSum = WaveletBase::DeviceSum(errorSquare);
// Print the data range, SNR (ratio and dB), max-abs error and RMSE; the error
// metrics are also printed divided by the data range.
491 std::cout <<
"Data range = " << range << std::endl;
492 std::cout <<
"SNR = " << snr << std::endl;
493 std::cout <<
"SNR in decibels = " << decibels << std::endl;
494 std::cout <<
"L-infy norm = " << errorMax
495 <<
", after normalization = " << errorMax / range << std::endl;
496 std::cout <<
"RMSE = " << rmse <<
", after normalization = " << rmse / range
// Bookkeeping for a 1D decomposition of a signal of length sigInLen.
// L gets nLevels + 2 entries:
//   - L[nLevels + 1] keeps sigInLen.
//   - Walking i from nLevels down to 1, L[i] is split into an approximation
//     length stored at L[i - 1] and a detail length stored back in L[i].
// Result: L[0] holds the final approximation length and L[1..nLevels] hold
// detail lengths.
// NOTE(review): nLev is cast to size_t without a visible sign check.
507 size_t nLevels =
static_cast<size_t>(nLev);
508 L.resize(nLevels + 2);
509 L[nLevels + 1] = sigInLen;
510 L[nLevels] = sigInLen;
511 for (
size_t i = nLevels; i > 0; i--)
513 L[i - 1] = WaveletBase::GetApproxLength(L[i]);
514 L[i] = WaveletBase::GetDetailLength(L[i]);
// Bookkeeping for a 2D decomposition of an inX x inY signal.
// L gets nLevels * 6 + 4 entries; the last four are (inX, inY, inX, inY).
// Walking i from nLevels down to 1, the (x, y) pair stored at L[i*6 + 0] and
// L[i*6 + 1] is expanded into four (x, y) pairs:
//   (approx, approx) at -6/-5, (approx, detail) at -4/-3,
//   (detail, approx) at -2/-1, (detail, detail) at +0/+1.
// The first pair becomes the input pair of iteration i - 1.
// NOTE(review): nLev is cast to size_t without a visible sign check.
521 size_t nLevels =
static_cast<size_t>(nLev);
522 L.resize(nLevels * 6 + 4);
523 L[nLevels * 6] = inX;
524 L[nLevels * 6 + 1] = inY;
525 L[nLevels * 6 + 2] = inX;
526 L[nLevels * 6 + 3] = inY;
528 for (
size_t i = nLevels; i > 0; i--)
531 L[i * 6 - 6] = WaveletBase::GetApproxLength(L[i * 6 + 0]);
532 L[i * 6 - 5] = WaveletBase::GetApproxLength(L[i * 6 + 1]);
535 L[i * 6 - 4] = WaveletBase::GetApproxLength(L[i * 6 + 0]);
536 L[i * 6 - 3] = WaveletBase::GetDetailLength(L[i * 6 + 1]);
539 L[i * 6 - 2] = WaveletBase::GetDetailLength(L[i * 6 + 0]);
540 L[i * 6 - 1] = WaveletBase::GetApproxLength(L[i * 6 + 1]);
// These two writes overwrite the input pair L[i*6 + 0] / L[i*6 + 1], so they
// must stay after every read of it above.
543 L[i * 6 - 0] = WaveletBase::GetDetailLength(L[i * 6 + 0]);
544 L[i * 6 + 1] = WaveletBase::GetDetailLength(L[i * 6 + 1]);
// Bookkeeping for a 3D decomposition. L gets n * 21 + 6 entries.
// The initialization of the trailing entries (original lines 553-559) is not
// visible in this excerpt.
// Walking i from n down to 1, the (x, y, z) triple at L[i*21 + 0..2] is
// expanded into eight triples, one per approx(A)/detail(D) combination.
// Offsets are relative to i * 21:
//   AAA -21..-19   AAD -18..-16   ADA -15..-13   ADD -12..-10
//   DAA  -9..-7    DAD  -6..-4    DDA  -3..-1    DDD  +0..+2
// The AAA triple becomes the input triple of iteration i - 1.
// FIX: the z entry of the DAD triple was written to offset -3, which belongs
// to DDA and was immediately overwritten. Offset -4 was therefore never set,
// so readers of L[21 * i + j] got a wrong z extent for that sub-band.
551 size_t n =
static_cast<size_t>(nLev);
552 L.resize(n * 21 + 6);
560 for (
size_t i = n; i > 0; i--)
// AAA
563 L[i * 21 - 21] = WaveletBase::GetApproxLength(L[i * 21 + 0]);
564 L[i * 21 - 20] = WaveletBase::GetApproxLength(L[i * 21 + 1]);
565 L[i * 21 - 19] = WaveletBase::GetApproxLength(L[i * 21 + 2]);
// AAD
568 L[i * 21 - 18] = L[i * 21 - 21];
569 L[i * 21 - 17] = L[i * 21 - 20];
570 L[i * 21 - 16] = WaveletBase::GetDetailLength(L[i * 21 + 2]);
// ADA
573 L[i * 21 - 15] = L[i * 21 - 21];
574 L[i * 21 - 14] = WaveletBase::GetDetailLength(L[i * 21 + 1]);
575 L[i * 21 - 13] = L[i * 21 - 19];
// ADD
578 L[i * 21 - 12] = L[i * 21 - 21];
579 L[i * 21 - 11] = L[i * 21 - 14];
580 L[i * 21 - 10] = L[i * 21 - 16];
// DAA
583 L[i * 21 - 9] = WaveletBase::GetDetailLength(L[i * 21 + 0]);
584 L[i * 21 - 8] = L[i * 21 - 20];
585 L[i * 21 - 7] = L[i * 21 - 19];
// DAD
588 L[i * 21 - 6] = L[i * 21 - 9];
589 L[i * 21 - 5] = L[i * 21 - 20];
590 L[i * 21 - 4] = L[i * 21 - 16];
// DDA
593 L[i * 21 - 3] = L[i * 21 - 9];
594 L[i * 21 - 2] = L[i * 21 - 14];
595 L[i * 21 - 1] = L[i * 21 - 19];
// DDD: overwrites the input triple, so it must stay after every read of it.
598 L[i * 21 + 0] = L[i * 21 - 9];
599 L[i * 21 + 1] = L[i * 21 - 14];
600 L[i * 21 + 2] = L[i * 21 - 16];
// NOTE(review): only the level-loop header of this routine is visible; it
// iterates i over 1..nLevels inclusive (the enclosing function and loop body
// are not in this excerpt).
608 for (
size_t i = 1; i <= size_t(nLevels); i++)
// For each level i in 1..nLevels, add the sizes of the three (x, y) pairs
// stored at offsets i*6-4/-3, i*6-2/-1 and i*6+0/+1 of L.
// The initial value of `sum` and the enclosing function are not visible here.
618 for (
size_t i = 1; i <= size_t(nLevels); i++)
620 sum += L[i * 6 - 4] * L[i * 6 - 3];
621 sum += L[i * 6 - 2] * L[i * 6 - 1];
622 sum += L[i * 6 - 0] * L[i * 6 + 1];
// Replace cALen with its approximation length (one level coarser); the
// enclosing loop and function are not visible in this excerpt.
633 cALen = WaveletBase::GetApproxLength(cALen);