11 #ifndef vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h
12 #define vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h
33 #include <type_traits>
// General implementations of the device adapter algorithms.
// The algorithm bodies in this struct call back into DerivedAlgorithm
// (Schedule, Synchronize, Copy, ScanInclusive, Sort, ...), and pass
// DeviceAdapterTag when preparing array / bit-field portals.
100 template <
class DerivedAlgorithm,
class DeviceAdapterTag>
101 struct DeviceAdapterAlgorithmGeneral
// Single-instance copy: builds a CopyKernel from inputPortal, outputPortal and
// `index`, then schedules it with exactly one instance.
// NOTE(review): the signature and portal setup are outside this excerpt; this
// looks like the GetExecutionValue helper that ScanInclusive calls with
// (array, index) -- confirm against the full file.
110 template <
typename T,
class CIn>
121 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
122 inputPortal, outputPortal, index);
124 DerivedAlgorithm::Schedule(kernel, 1);
// Runs BitFieldToUnorderedSetFunctor over bitsPortal, writing into the
// `indices` array and accumulating a count in popCount.
133 template <
typename IndicesStorage>
// `indices` is first prepared with room for numBits entries.
145 auto indicesPortal = indices.
PrepareForOutput(numBits, DeviceAdapterTag{}, token);
// Counter passed to the functor; zeroed before any work is scheduled.
147 std::atomic<vtkm::UInt64> popCount;
148 popCount.store(0, std::memory_order_seq_cst);
150 using Functor = BitFieldToUnorderedSetFunctor<decltype(bitsPortal), decltype(indicesPortal)>;
151 Functor functor{ bitsPortal, indicesPortal, popCount };
// The functor itself reports how many instances to schedule. Synchronize()
// is called before popCount is read back below.
153 DerivedAlgorithm::Schedule(functor, functor.GetNumberOfInstances());
154 DerivedAlgorithm::Synchronize();
// numBits is overwritten with the final count (UInt64 cast to vtkm::Id).
158 numBits =
static_cast<vtkm::Id>(popCount.load(std::memory_order_seq_cst));
// Array copy fragment: `output` is prepared for inSize values and a
// two-portal CopyKernel (no offsets) is scheduled over inSize instances.
// The input portal setup and the function signature are not visible here.
166 template <
typename T,
typename U,
class CIn,
class COut>
176 auto outputPortal = output.
PrepareForOutput(inSize, DeviceAdapterTag(), token);
178 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(inputPortal, outputPortal);
179 DerivedAlgorithm::Schedule(kernel, inSize);
// Stencil-driven copy with a caller-supplied unary predicate. Visible steps:
//   1. StencilToIndexFlagKernel fills `indices` from the stencil and predicate.
//   2. An in-place ScanExclusive over `indices` returns outArrayLength.
//   3. CopyIfKernel is scheduled over arrayLength instances with `output`
//      prepared for outArrayLength values.
184 template <
typename T,
typename U,
class CIn,
class CStencil,
class COut,
class UnaryPredicate>
188 UnaryPredicate unary_predicate)
// Scratch array: written by step 1, scanned in place by step 2.
196 IndexArrayType indices;
201 auto stencilPortal = stencil.
PrepareForInput(DeviceAdapterTag(), token);
202 auto indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag(), token);
204 StencilToIndexFlagKernel<decltype(stencilPortal), decltype(indexPortal), UnaryPredicate>
205 indexKernel(stencilPortal, indexPortal, unary_predicate);
207 DerivedAlgorithm::Schedule(indexKernel, arrayLength);
// The value returned by the scan is used as the size of `output` below.
210 vtkm::Id outArrayLength = DerivedAlgorithm::ScanExclusive(indices, indices);
216 auto stencilPortal = stencil.
PrepareForInput(DeviceAdapterTag(), token);
// NOTE(review): `indices` is prepared with PrepareForOutput again here, yet
// its scanned contents are presumably read by CopyIfKernel. Confirm that
// PrepareForOutput at an unchanged size preserves the data, or whether an
// input / in-place preparation was intended.
217 auto indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag(), token);
218 auto outputPortal = output.
PrepareForOutput(outArrayLength, DeviceAdapterTag(), token);
220 CopyIfKernel<decltype(inputPortal),
221 decltype(stencilPortal),
222 decltype(indexPortal),
223 decltype(outputPortal),
225 copyKernel(inputPortal, stencilPortal, indexPortal, outputPortal, unary_predicate);
226 DerivedAlgorithm::Schedule(copyKernel, arrayLength);
// CopyIf overload without a predicate template parameter: forwards to the
// four-argument DerivedAlgorithm::CopyIf with a `unary_predicate` whose
// construction is not visible in this excerpt.
230 template <
typename T,
typename U,
class CIn,
class CStencil,
class COut>
238 DerivedAlgorithm::CopyIf(input, stencil, output, unary_predicate);
// Copies numberOfElementsToCopy values from `input` starting at
// inputStartIndex into `output` starting at outputIndex.
// Only the conditions and key calls are visible in this excerpt; what each
// branch does when its condition holds is not shown.
243 template <
typename T,
typename U,
class CIn,
class COut>
// Same array on both sides AND the source range overlaps the destination
// range (the start of either range falls inside the other).
255 if (input == output &&
256 ((outputIndex >= inputStartIndex &&
257 outputIndex < inputStartIndex + numberOfElementsToCopy) ||
258 (inputStartIndex >= outputIndex &&
259 inputStartIndex < outputIndex + numberOfElementsToCopy)))
// Invalid request: a negative index or count, or a start at/past the end of input.
264 if (inputStartIndex < 0 || numberOfElementsToCopy < 0 || outputIndex < 0 ||
265 inputStartIndex >= inSize)
// Shrink the count so the copy does not read past the end of input.
271 if (inSize < (inputStartIndex + numberOfElementsToCopy))
273 numberOfElementsToCopy = (inSize - inputStartIndex);
// Detect an output that is too small to hold the copied range.
277 const vtkm::Id copyOutEnd = outputIndex + numberOfElementsToCopy;
278 if (outSize < copyOutEnd)
// The first outSize values of `output` are copied into `temp`
// (presumably to keep them across a resize -- surrounding lines not shown).
289 DerivedAlgorithm::CopySubRange(output, 0, outSize, temp);
// The kernel receives both start offsets and is scheduled over the
// (possibly reduced) element count.
299 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
300 inputPortal, outputPortal, inputStartIndex, outputIndex);
301 DerivedAlgorithm::Schedule(kernel, numberOfElementsToCopy);
// Counts via CountSetBitsFunctor: popCount is zeroed (relaxed store), passed
// to the functor together with bitsPortal, and read back (seq_cst load) after
// the scheduled work has been followed by Synchronize().
315 std::atomic<vtkm::UInt64> popCount;
316 popCount.store(0, std::memory_order_relaxed);
318 using Functor = CountSetBitsFunctor<decltype(bitsPortal)>;
319 Functor functor{ bitsPortal, popCount };
// The functor reports how many instances it needs.
321 DerivedAlgorithm::Schedule(functor, functor.GetNumberOfInstances());
322 DerivedAlgorithm::Synchronize();
// The UInt64 total is returned as a vtkm::Id.
324 return static_cast<vtkm::Id>(popCount.load(std::memory_order_seq_cst));
344 typename vtkm::cont::BitField::template ExecutionTypes<DeviceAdapterTag>::WordTypePreferred;
// Boolean bit-field fill: the fill word is all ones when `value` is true and
// all zeros otherwise; FillBitFieldFunctor is scheduled once per WordType
// word reported by the portal.
346 using Functor = FillBitFieldFunctor<decltype(portal), WordType>;
347 Functor functor{ portal, value ? ~WordType{ 0 } : WordType{ 0 } };
349 const vtkm::Id numWords = portal.template GetNumberOfWords<WordType>();
350 DerivedAlgorithm::Schedule(functor, numWords);
370 typename vtkm::cont::BitField::template ExecutionTypes<DeviceAdapterTag>::WordTypePreferred;
// Same visible steps as the other boolean bit-field fill: choose an all-ones
// or all-zeros WordType from `value`, then schedule FillBitFieldFunctor over
// every WordType word in the portal. How `portal` is prepared is not shown.
372 using Functor = FillBitFieldFunctor<decltype(portal), WordType>;
373 Functor functor{ portal, value ? ~WordType{ 0 } : WordType{ 0 } };
375 const vtkm::Id numWords = portal.template GetNumberOfWords<WordType>();
376 DerivedAlgorithm::Schedule(functor, numWords);
// Bit-field fill from a caller-supplied word pattern. The pattern is first
// passed through RepeatTo32BitsIfNeeded (helper not visible here; by its name
// it widens narrow words by repetition -- confirm), and the fill then runs
// once per word of the resulting type.
381 template <
typename WordType>
401 auto repWord = RepeatTo32BitsIfNeeded(word);
402 using RepWordType = decltype(repWord);
404 using Functor = FillBitFieldFunctor<decltype(portal), RepWordType>;
405 Functor functor{ portal, repWord };
// Word count is measured in RepWordType units, not the caller's WordType.
407 const vtkm::Id numWords = portal.template GetNumberOfWords<RepWordType>();
408 DerivedAlgorithm::Schedule(functor, numWords);
// Second word-pattern bit-field fill overload; the visible steps match the
// other one: convert `word` with RepeatTo32BitsIfNeeded, then schedule
// FillBitFieldFunctor once per RepWordType word in the portal.
// The signature and portal preparation are outside this excerpt.
413 template <
typename WordType>
432 auto repWord = RepeatTo32BitsIfNeeded(word);
433 using RepWordType = decltype(repWord);
435 using Functor = FillBitFieldFunctor<decltype(portal), RepWordType>;
436 Functor functor{ portal, repWord };
438 const vtkm::Id numWords = portal.template GetNumberOfWords<RepWordType>();
439 DerivedAlgorithm::Schedule(functor, numValues);
// Array fill fragment: `handle` is prepared for numValues outputs and
// FillArrayHandleFunctor(portal, value) is scheduled over numValues instances.
// Where numValues comes from is not visible in this excerpt.
444 template <
typename T,
typename S>
457 auto portal = handle.
PrepareForOutput(numValues, DeviceAdapterTag{}, token);
458 FillArrayHandleFunctor<decltype(portal)> functor{ portal, value };
459 DerivedAlgorithm::Schedule(functor, numValues);
// Second array fill overload; the visible steps are identical: prepare
// `handle` for numValues outputs, then schedule FillArrayHandleFunctor with
// `value` over numValues instances. The differing signature is not shown.
464 template <
typename T,
typename S>
478 auto portal = handle.
PrepareForOutput(numValues, DeviceAdapterTag{}, token);
479 FillArrayHandleFunctor<decltype(portal)> functor{ portal, value };
480 DerivedAlgorithm::Schedule(functor, numValues);
// Lower-bound search fragment: `output` is prepared for arraySize values and
// LowerBoundsKernel(inputPortal, valuesPortal, outputPortal) is scheduled
// over arraySize instances. Input/values portal setup is not visible here.
485 template <
typename T,
class CIn,
class CVal,
class COut>
498 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
500 LowerBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
501 inputPortal, valuesPortal, outputPortal);
503 DerivedAlgorithm::Schedule(kernel, arraySize);
// Lower-bound search with a caller-supplied comparison: same shape as the
// plain version, but the kernel is LowerBoundsComparisonKernel and also
// receives binary_compare.
506 template <
typename T,
class CIn,
class CVal,
class COut,
class BinaryCompare>
510 BinaryCompare binary_compare)
520 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
522 LowerBoundsComparisonKernel<decltype(inputPortal),
523 decltype(valuesPortal),
524 decltype(outputPortal),
526 kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
528 DerivedAlgorithm::Schedule(kernel, arraySize);
// In-place LowerBounds: values_output is passed in both the second and third
// argument positions of the three-argument LowerBounds. The call is qualified
// with this struct's own name rather than DerivedAlgorithm.
531 template <
class CIn,
class COut>
537 DeviceAdapterAlgorithmGeneral<DerivedAlgorithm, DeviceAdapterTag>::LowerBounds(
538 input, values_output, values_output);
// Holds an initial value and a binary operator, and builds a ReduceKernel
// over a given portal from them. Reduce constructs one of these from
// (initialValue, binary_functor).
547 template <
typename T,
typename BinaryFunctor>
548 class ReduceDecoratorImpl
551 VTKM_CONT ReduceDecoratorImpl() =
default;
// Stores copies of both arguments in the InitialValue / ReduceOperator members.
554 ReduceDecoratorImpl(
const T& initialValue,
const BinaryFunctor& binaryFunctor)
555 : InitialValue(initialValue)
556 , ReduceOperator(binaryFunctor)
// Returns a ReduceKernel bound to `portal` plus the stored value and operator.
560 template <
typename Portal>
561 VTKM_CONT ReduceKernel<Portal, T, BinaryFunctor> CreateFunctor(
const Portal& portal)
const
563 return ReduceKernel<Portal, T, BinaryFunctor>(
564 portal, this->InitialValue, this->ReduceOperator);
569 BinaryFunctor ReduceOperator;
// Reduce overload without an operator argument: forwards to
// DerivedAlgorithm::Reduce with vtkm::Add() as the binary operator.
573 template <
typename T,
typename U,
class CIn>
578 return DerivedAlgorithm::Reduce(input, initialValue,
vtkm::Add());
// Reduce with a caller-supplied binary operator. Visible steps: a
// ReduceDecoratorImpl<U, BinaryFunctor> built from (initialValue,
// binary_functor) is passed along with `length` and `input` to a callee that
// is cut off in this excerpt, and ScanInclusive is then run from `reduced`
// into `inclusiveScanStorage` with the same operator.
// NOTE(review): how `reduced` is produced and what is returned are not
// visible here -- confirm against the full file.
581 template <
typename T,
typename U,
class CIn,
class BinaryFunctor>
584 BinaryFunctor binary_functor)
597 length, ReduceDecoratorImpl<U, BinaryFunctor>(initialValue, binary_functor), input);
601 DerivedAlgorithm::ScanInclusive(reduced, inclusiveScanStorage, binary_functor);
// Keyed reduction. Visible steps, in order:
//   - with at most one key, keys and values are copied straight to the outputs;
//   - ReduceStencilGeneration fills `keystate`, one instance per key;
//   - ScanInclusive runs from scanInput to scanOutput with the user operator
//     wrapped in ReduceByKeyAdd;
//   - CopyIf compacts reducedValues into values_output using `stencil` and
//     ReduceByKeyUnaryStencilOp;
//   - keys are copied to keys_output and then passed to Unique.
// How scanInput, scanOutput, reducedValues and stencil are built is not shown.
607 template <
typename T,
618 BinaryFunctor binary_functor)
627 if (numberOfKeys <= 1)
629 DerivedAlgorithm::Copy(keys, keys_output);
630 DerivedAlgorithm::Copy(values, values_output);
642 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
643 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
644 inputPortal, keyStatePortal);
645 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
661 DerivedAlgorithm::ScanInclusive(
662 scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
670 DerivedAlgorithm::CopyIf(reducedValues, stencil, values_output, ReduceByKeyUnaryStencilOp());
679 DerivedAlgorithm::Copy(keys, keys_output);
680 DerivedAlgorithm::Unique(keys_output);
// Exclusive scan built on the inclusive one: ScanInclusive writes into the
// temporary `inclusiveScan`, then InclusiveToExclusiveKernel produces `output`
// (numValues entries) from it using binaryFunctor and initialValue.
686 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
689 BinaryFunctor binaryFunctor,
690 const T& initialValue)
// `result` is whatever the inclusive scan returns; it is reused for the
// return value below.
702 T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor);
706 auto inputPortal = inclusiveScan.
PrepareForInput(DeviceAdapterTag(), token);
707 auto outputPortal = output.
PrepareForOutput(numValues, DeviceAdapterTag(), token);
709 InclusiveToExclusiveKernel<decltype(inputPortal), decltype(outputPortal), BinaryFunctor>
710 inclusiveToExclusive(inputPortal, outputPortal, binaryFunctor, initialValue);
712 DerivedAlgorithm::Schedule(inclusiveToExclusive, numValues);
// The returned value combines initialValue with the inclusive-scan result.
714 return binaryFunctor(initialValue, result);
// ScanExclusive overload with no operator template parameter: returns the
// result of another DerivedAlgorithm::ScanExclusive call. The arguments of
// that call are cut off in this excerpt.
717 template <
typename T,
class CIn,
class COut>
723 return DerivedAlgorithm::ScanExclusive(
// Extended scan: like the exclusive scan above it, it starts from an
// inclusive scan into `inclusiveScan`, but `output` is prepared for
// numValues + 1 entries and the conversion kernel is scheduled over
// numValues + 1 instances.
729 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
732 BinaryFunctor binaryFunctor,
733 const T& initialValue)
746 T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor);
750 auto inputPortal = inclusiveScan.
PrepareForInput(DeviceAdapterTag(), token);
751 auto outputPortal = output.
PrepareForOutput(numValues + 1, DeviceAdapterTag(), token);
// The last visible constructor argument is binaryFunctor(initialValue, result);
// the arguments between inputPortal and it are not shown in this excerpt.
753 InclusiveToExtendedKernel<decltype(inputPortal), decltype(outputPortal), BinaryFunctor>
754 inclusiveToExtended(inputPortal,
758 binaryFunctor(initialValue, result));
760 DerivedAlgorithm::Schedule(inclusiveToExtended, numValues + 1);
// ScanExtended overload with no operator template parameter: delegates to
// another DerivedAlgorithm::ScanExtended call whose arguments are cut off in
// this excerpt.
763 template <
typename T,
class CIn,
class COut>
769 DerivedAlgorithm::ScanExtended(
// Keyed exclusive scan. Visible steps for the general case:
//   - ReduceStencilGeneration fills `keystate`, one instance per key;
//   - ShiftCopyAndInit writes `temp` from the input values, keystate and
//     initialValue;
//   - ScanInclusiveByKey(keys, temp, output, binaryFunctor) produces the result.
// The bodies of the 0-key and 1-key special cases are not shown.
775 template <
typename KeyT,
784 const ValueT& initialValue,
785 BinaryFunctor binaryFunctor)
794 if (numberOfKeys == 0)
798 else if (numberOfKeys == 1)
814 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
815 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
816 inputPortal, keyStatePortal);
817 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
// keystate is now read-only input; `temp` receives the kernel's output.
825 auto keyStatePortal = keystate.
PrepareForInput(DeviceAdapterTag(), token);
826 auto tempPortal = temp.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
828 ShiftCopyAndInit<ValueT,
829 decltype(inputPortal),
830 decltype(keyStatePortal),
831 decltype(tempPortal)>
832 kernel(inputPortal, keyStatePortal, tempPortal, initialValue);
833 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
836 DerivedAlgorithm::ScanInclusiveByKey(keys, temp, output, binaryFunctor);
// ScanExclusiveByKey overload with no operator template parameter: delegates
// to another DerivedAlgorithm::ScanExclusiveByKey call whose arguments are cut
// off in this excerpt.
839 template <
typename KeyT,
typename ValueT,
class KIn,
typename VIn,
typename VOut>
846 DerivedAlgorithm::ScanExclusiveByKey(
// ScanInclusive overload without an operator argument: forwards to
// DerivedAlgorithm::ScanInclusive with vtkm::Add() as the binary operator.
852 template <
typename T,
class CIn,
class COut>
858 return DerivedAlgorithm::ScanInclusive(input, output,
vtkm::Add());
862 template <
typename T1,
typename S1,
typename T2,
typename S2>
869 template <
typename T,
typename S>
// Inclusive scan with a caller-supplied binary operator, computed with
// repeated ScanKernel passes over `portal`.
877 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
880 BinaryFunctor binary_functor)
// Input is copied into output first unless both refer to the same array.
884 if (!ArrayHandlesAreSame(input, output))
886 DerivedAlgorithm::Copy(input, output);
899 using ScanKernelType = ScanKernel<decltype(portal), BinaryFunctor>;
// First set of passes: stride doubles from 2 while stride - 1 < numValues.
// Each pass uses offset stride / 2 - 1 and numValues / stride instances.
903 for (stride = 2; stride - 1 < numValues; stride *= 2)
905 ScanKernelType kernel(portal, binary_functor, stride, stride / 2 - 1);
906 DerivedAlgorithm::Schedule(kernel, numValues / stride);
// Second set of passes: stride is halved each time and the loop stops once it
// reaches 1; these passes use offset stride - 1.
910 for (stride /= 2; stride > 1; stride /= 2)
912 ScanKernelType kernel(portal, binary_functor, stride, stride - 1);
913 DerivedAlgorithm::Schedule(kernel, numValues / stride);
// The return value is fetched from index numValues - 1 of output.
917 return GetExecutionValue(output, numValues - 1);
// ScanInclusiveByKey overload without an operator argument: forwards to
// DerivedAlgorithm::ScanInclusiveByKey with vtkm::Add() as the binary operator.
920 template <
typename KeyT,
typename ValueT,
class KIn,
class VIn,
class VOut>
927 return DerivedAlgorithm::ScanInclusiveByKey(keys, values, values_output,
vtkm::Add());
// Keyed inclusive scan with a caller-supplied operator. Visible steps:
//   - with at most one key, values are copied straight to values_output;
//   - ReduceStencilGeneration fills `keystate`, one instance per key;
//   - ScanInclusive runs from scanInput to scanOutput with the operator
//     wrapped in ReduceByKeyAdd;
//   - reducedValues is copied to values_output.
// How scanInput, scanOutput and reducedValues are built is not shown.
930 template <
typename KeyT,
typename ValueT,
class KIn,
class VIn,
class VOut,
class BinaryFunctor>
934 BinaryFunctor binary_functor)
941 if (numberOfKeys <= 1)
943 DerivedAlgorithm::Copy(values, values_output);
955 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
956 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
957 inputPortal, keyStatePortal);
958 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
973 DerivedAlgorithm::ScanInclusive(
974 scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
977 DerivedAlgorithm::Copy(reducedValues, values_output);
// Sort with a caller-supplied comparison, built from alternating
// BitonicSortCrossoverKernel and BitonicSortMergeKernel passes over `portal`.
983 template <
typename T,
class Storage,
class BinaryCompare>
985 BinaryCompare binary_compare)
// numThreads is grown until it is no longer below numValues; the loop body
// that updates it is not visible in this excerpt.
995 while (numThreads < numValues)
1004 using MergeKernel = BitonicSortMergeKernel<decltype(portal), BinaryCompare>;
1005 using CrossoverKernel = BitonicSortCrossoverKernel<decltype(portal), BinaryCompare>;
// Outer loop: crossoverSize doubles from 1 while it is below numValues, with
// one crossover pass per step. Inner loop: merge passes with mergeSize
// starting at crossoverSize / 2 and halving down to 1, each scheduled over
// numThreads instances.
1007 for (
vtkm::Id crossoverSize = 1; crossoverSize < numValues; crossoverSize *= 2)
1009 DerivedAlgorithm::Schedule(CrossoverKernel(portal, binary_compare, crossoverSize),
1011 for (
vtkm::Id mergeSize = crossoverSize / 2; mergeSize > 0; mergeSize /= 2)
1013 DerivedAlgorithm::Schedule(MergeKernel(portal, binary_compare, mergeSize), numThreads);
// Sort overload without a comparison argument: forwards to
// DerivedAlgorithm::Sort with a default-constructed DefaultCompareFunctor.
1018 template <
typename T,
class Storage>
1023 DerivedAlgorithm::Sort(values, DefaultCompareFunctor());
// Key/value sort expressed as a plain Sort of `zipHandle` ordered by
// internal::KeyCompare<T, U>. The construction of zipHandle is not visible
// here (presumably keys and values zipped together -- confirm).
1028 template <
typename T,
typename U,
class StorageT,
class StorageU>
1038 DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U>());
// Key/value sort with a caller-supplied comparison: sorts `zipHandle` with
// internal::KeyCompare<T, U, BinaryCompare> constructed from binary_compare.
// The construction of zipHandle is not visible in this excerpt.
1041 template <
typename T,
typename U,
class StorageT,
class StorageU,
class BinaryCompare>
1044 BinaryCompare binary_compare)
1053 DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare));
// Two-input transform fragment: both inputs are prepared for reading,
// `output` is prepared for numValues values, and BinaryTransformKernel is
// scheduled over numValues instances with binaryFunctor.
// How numValues is derived from the inputs is not visible in this excerpt.
1056 template <
typename T,
1062 typename BinaryFunctor>
1066 BinaryFunctor binaryFunctor)
1078 auto input1Portal = input1.
PrepareForInput(DeviceAdapterTag(), token);
1079 auto input2Portal = input2.
PrepareForInput(DeviceAdapterTag(), token);
1080 auto outputPortal = output.
PrepareForOutput(numValues, DeviceAdapterTag(), token);
1082 BinaryTransformKernel<decltype(input1Portal),
1083 decltype(input2Portal),
1084 decltype(outputPortal),
1086 binaryKernel(input1Portal, input2Portal, outputPortal, binaryFunctor);
1087 DerivedAlgorithm::Schedule(binaryKernel, numValues);
1093 template <
typename T,
class Storage>
// Duplicate removal with a caller-supplied comparison. Visible steps:
//   - ClassifyUniqueComparisonKernel fills `stencilArray`, one instance per
//     input value;
//   - CopyIf compacts `values` into `outputArray` using that stencil;
//   - `outputArray` is copied back into `values`.
1101 template <
typename T,
class Storage,
class BinaryCompare>
1103 BinaryCompare binary_compare)
// The comparison is wrapped so that it is used as a bool-returning operator.
1110 using WrappedBOpType = internal::WrappedBinaryOperator<bool, BinaryCompare>;
1111 WrappedBOpType wrappedCompare(binary_compare);
1115 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1116 auto stencilPortal = stencilArray.
PrepareForOutput(inputSize, DeviceAdapterTag(), token);
1117 ClassifyUniqueComparisonKernel<decltype(valuesPortal),
1118 decltype(stencilPortal),
1120 classifyKernel(valuesPortal, stencilPortal, wrappedCompare);
1122 DerivedAlgorithm::Schedule(classifyKernel, inputSize);
1127 DerivedAlgorithm::CopyIf(values, stencilArray, outputArray);
1130 DerivedAlgorithm::Copy(outputArray, values);
// Upper-bound search fragment: `values` is prepared for reading, `output` for
// arraySize values, and UpperBoundsKernel(inputPortal, valuesPortal,
// outputPortal) is scheduled over arraySize instances.
1135 template <
typename T,
class CIn,
class CVal,
class COut>
1147 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1148 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
1150 UpperBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
1151 inputPortal, valuesPortal, outputPortal);
1152 DerivedAlgorithm::Schedule(kernel, arraySize);
// Upper-bound search with a caller-supplied comparison: same shape as the
// plain version, but the kernel is UpperBoundsKernelComparisonKernel and also
// receives binary_compare.
1155 template <
typename T,
class CIn,
class CVal,
class COut,
class BinaryCompare>
1159 BinaryCompare binary_compare)
1168 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1169 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
1171 UpperBoundsKernelComparisonKernel<decltype(inputPortal),
1172 decltype(valuesPortal),
1173 decltype(outputPortal),
1175 kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
1177 DerivedAlgorithm::Schedule(kernel, arraySize);
// In-place UpperBounds: values_output is passed in both the second and third
// argument positions of the three-argument UpperBounds. The call is qualified
// with this struct's own name rather than DerivedAlgorithm.
1180 template <
class CIn,
class COut>
1186 DeviceAdapterAlgorithmGeneral<DerivedAlgorithm, DeviceAdapterTag>::UpperBounds(
1187 input, values_output, values_output);
// Task factory parameterised on a device tag. Both visible MakeTask overloads
// return a vtkm::exec::internal::TaskSingular built from
// (worklet, invocation, globalIndexOffset).
1203 template <
typename DeviceTag>
1204 class DeviceTaskTypes
// First MakeTask overload. The parameters after `invocation` (including the
// declaration of globalIndexOffset) are cut off in this excerpt.
1207 template <
typename WorkletType,
typename InvocationType>
1208 static vtkm::exec::internal::TaskSingular<WorkletType, InvocationType>
MakeTask(
1209 WorkletType& worklet,
1210 InvocationType& invocation,
1214 using Task = vtkm::exec::internal::TaskSingular<WorkletType, InvocationType>;
1215 return Task(worklet, invocation, globalIndexOffset);
// Second MakeTask overload. Its visible lines are identical to the first;
// the parameters that distinguish it are not shown here.
1218 template <
typename WorkletType,
typename InvocationType>
1219 static vtkm::exec::internal::TaskSingular<WorkletType, InvocationType>
MakeTask(
1220 WorkletType& worklet,
1221 InvocationType& invocation,
1225 using Task = vtkm::exec::internal::TaskSingular<WorkletType, InvocationType>;
1226 return Task(worklet, invocation, globalIndexOffset);
1232 #endif //vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h