11 #ifndef vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h
12 #define vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h
32 #include <type_traits>
// Generic, device-independent implementations of the device-adapter algorithms
// (copy, fill, bounds search, reduce, scan, sort, unique, ...). Every algorithm is
// expressed through calls back into DerivedAlgorithm (Schedule, Synchronize and the
// other DerivedAlgorithm:: calls below).
// NOTE(review): DerivedAlgorithm is presumably the concrete per-device algorithm type
// that derives from this struct (CRTP style) — confirm at the device specializations.
// DeviceAdapterTag is the tag passed to PrepareForInput/PrepareForOutput throughout.
89 template <
class DerivedAlgorithm,
class DeviceAdapterTag>
90 struct DeviceAdapterAlgorithmGeneral
// Single-element transfer: builds a CopyKernel over (inputPortal, outputPortal, index)
// and schedules exactly one instance of it.
// NOTE(review): presumably this copies the element at `index` into a one-element
// output that is read back afterwards; the portal setup and read-back are elided here.
99 template <
typename T,
class CIn>
110 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
111 inputPortal, outputPortal, index);
// One instance: only a single value is moved.
113 DerivedAlgorithm::Schedule(kernel, 1);
// Collects positions of the set bits of a bit field into `indices` and stores the
// final counter value back into numBits.
// NOTE(review): "unordered" (from BitFieldToUnorderedSetFunctor) presumably means the
// indices are not written in sorted order — confirm against the functor.
122 template <
typename IndicesStorage>
// The output is prepared with numBits entries. numBits presumably still holds the bit
// field's length here, which leaves room for the case where every bit is set.
134 auto indicesPortal = indices.
PrepareForOutput(numBits, DeviceAdapterTag{}, token);
// Shared atomic counter the functor is constructed with; reset before any work is
// scheduled.
136 std::atomic<vtkm::UInt64> popCount;
137 popCount.store(0, std::memory_order_seq_cst);
139 using Functor = BitFieldToUnorderedSetFunctor<decltype(bitsPortal), decltype(indicesPortal)>;
140 Functor functor{ bitsPortal, indicesPortal, popCount };
// The functor itself reports how many instances it needs.
142 DerivedAlgorithm::Schedule(functor, functor.GetNumberOfInstances());
// Wait for all scheduled work before the counter is read.
143 DerivedAlgorithm::Synchronize();
// The counter's final value replaces numBits.
147 numBits =
static_cast<vtkm::Id>(popCount.load(std::memory_order_seq_cst));
// Copy: prepares `output` with inSize values and schedules inSize CopyKernel instances
// to move the input across.
// NOTE(review): the input element type T and output type U may differ, so CopyKernel
// presumably converts each element — confirm.
155 template <
typename T,
typename U,
class CIn,
class COut>
165 auto outputPortal = output.
PrepareForOutput(inSize, DeviceAdapterTag(), token);
167 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(inputPortal, outputPortal);
// One instance per output value.
168 DerivedAlgorithm::Schedule(kernel, inSize);
// CopyIf (stream compaction): copies the input values whose stencil entry satisfies
// unary_predicate into `output`, in three passes:
//   1. flag    - StencilToIndexFlagKernel writes one entry of `indices` per stencil entry;
//   2. scan    - an in-place exclusive scan turns those flags into output offsets;
//   3. scatter - CopyIfKernel writes the selected values to `output`.
173 template <
typename T,
typename U,
class CIn,
class CStencil,
class COut,
class UnaryPredicate>
177 UnaryPredicate unary_predicate)
// Scratch array: per-element flags, later overwritten in place by their scan.
185 IndexArrayType indices;
// Pass 1: flag every stencil entry (presumably 1 where unary_predicate holds, else 0).
190 auto stencilPortal = stencil.
PrepareForInput(DeviceAdapterTag(), token);
191 auto indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag(), token);
193 StencilToIndexFlagKernel<decltype(stencilPortal), decltype(indexPortal), UnaryPredicate>
194 indexKernel(stencilPortal, indexPortal, unary_predicate);
196 DerivedAlgorithm::Schedule(indexKernel, arrayLength);
// Pass 2: in-place exclusive scan; its return value becomes the compacted output length.
199 vtkm::Id outArrayLength = DerivedAlgorithm::ScanExclusive(indices, indices);
// Pass 3: scatter the selected inputs using the scanned offsets.
// NOTE(review): `indices` is re-prepared with PrepareForOutput at the same length even
// though this pass presumably reads the offsets written by pass 2. This relies on
// PrepareForOutput keeping existing contents when the size is unchanged — confirm.
205 auto stencilPortal = stencil.
PrepareForInput(DeviceAdapterTag(), token);
206 auto indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag(), token);
207 auto outputPortal = output.
PrepareForOutput(outArrayLength, DeviceAdapterTag(), token);
209 CopyIfKernel<decltype(inputPortal),
210 decltype(stencilPortal),
211 decltype(indexPortal),
212 decltype(outputPortal),
214 copyKernel(inputPortal, stencilPortal, indexPortal, outputPortal, unary_predicate);
215 DerivedAlgorithm::Schedule(copyKernel, arrayLength);
// CopyIf overload without a caller-supplied predicate: forwards to the predicate
// overload with a default unary_predicate (its declaration is elided from this view).
219 template <
typename T,
typename U,
class CIn,
class CStencil,
class COut>
227 DerivedAlgorithm::CopyIf(input, stencil, output, unary_predicate);
// CopySubRange: copies numberOfElementsToCopy values of `input`, starting at
// inputStartIndex, into `output` starting at outputIndex.
232 template <
typename T,
typename U,
class CIn,
class COut>
// Same-array copy whose source and destination ranges overlap. The branch body is
// elided; presumably it reports failure.
244 if (input == output &&
245 ((outputIndex >= inputStartIndex &&
246 outputIndex < inputStartIndex + numberOfElementsToCopy) ||
247 (inputStartIndex >= outputIndex &&
248 inputStartIndex < outputIndex + numberOfElementsToCopy)))
// Negative indices/counts, or a start index at or beyond the end of the input.
253 if (inputStartIndex < 0 || numberOfElementsToCopy < 0 || outputIndex < 0 ||
254 inputStartIndex >= inSize)
// Clamp the count so the copy never reads past the end of the input.
260 if (inSize < (inputStartIndex + numberOfElementsToCopy))
262 numberOfElementsToCopy = (inSize - inputStartIndex);
// Check whether the destination is long enough for outputIndex + numberOfElementsToCopy values.
266 const vtkm::Id copyOutEnd = outputIndex + numberOfElementsToCopy;
267 if (outSize < copyOutEnd)
// Grow path (mostly elided): the current output contents are first copied into `temp`.
278 DerivedAlgorithm::CopySubRange(output, 0, outSize, temp);
// numberOfElementsToCopy CopyKernel instances, offset by inputStartIndex / outputIndex.
288 CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
289 inputPortal, outputPortal, inputStartIndex, outputIndex);
290 DerivedAlgorithm::Schedule(kernel, numberOfElementsToCopy);
// Counts the set bits of a bit field. CountSetBitsFunctor instances presumably add
// into the shared atomic counter, whose value is returned once the work has finished.
// The counter is reset before scheduling and only read after Synchronize().
304 std::atomic<vtkm::UInt64> popCount;
305 popCount.store(0, std::memory_order_relaxed);
307 using Functor = CountSetBitsFunctor<decltype(bitsPortal)>;
308 Functor functor{ bitsPortal, popCount };
310 DerivedAlgorithm::Schedule(functor, functor.GetNumberOfInstances());
311 DerivedAlgorithm::Synchronize();
313 return static_cast<vtkm::Id>(popCount.load(std::memory_order_seq_cst));
// Bit-field fill with a bool. The fill word uses the device's preferred word type and
// is all ones (~0) when `value` is true, otherwise all zeros. numWords instances are
// scheduled (presumably one word written per instance).
333 typename vtkm::cont::BitField::template ExecutionTypes<DeviceAdapterTag>::WordTypePreferred;
335 using Functor = FillBitFieldFunctor<decltype(portal), WordType>;
336 Functor functor{ portal, value ? ~WordType{ 0 } : WordType{ 0 } };
338 const vtkm::Id numWords = portal.template GetNumberOfWords<WordType>();
339 DerivedAlgorithm::Schedule(functor, numWords);
// Bit-field fill with a bool (second overload; its signature is elided from this view).
// Same scheme: a word of the device's preferred type, ~0 for true or 0 for false, with
// numWords functor instances scheduled.
359 typename vtkm::cont::BitField::template ExecutionTypes<DeviceAdapterTag>::WordTypePreferred;
361 using Functor = FillBitFieldFunctor<decltype(portal), WordType>;
362 Functor functor{ portal, value ? ~WordType{ 0 } : WordType{ 0 } };
364 const vtkm::Id numWords = portal.template GetNumberOfWords<WordType>();
365 DerivedAlgorithm::Schedule(functor, numWords);
// Bit-field fill with a caller-supplied word pattern. The pattern is passed through
// RepeatTo32BitsIfNeeded (presumably replicating words narrower than 32 bits to fill
// 32 bits — confirm), and the fill then works in units of the resulting type.
370 template <
typename WordType>
390 auto repWord = RepeatTo32BitsIfNeeded(word);
391 using RepWordType = decltype(repWord);
393 using Functor = FillBitFieldFunctor<decltype(portal), RepWordType>;
394 Functor functor{ portal, repWord };
// Count words in RepWordType units so the count matches the functor's word type.
396 const vtkm::Id numWords = portal.template GetNumberOfWords<RepWordType>();
397 DerivedAlgorithm::Schedule(functor, numWords);
// Bit-field fill with a word pattern (second overload; signature elided from this view).
// The pattern is widened with RepeatTo32BitsIfNeeded (presumably only when narrower
// than 32 bits — confirm), and one functor instance is scheduled per RepWordType word.
402 template <
typename WordType>
421 auto repWord = RepeatTo32BitsIfNeeded(word);
422 using RepWordType = decltype(repWord);
424 using Functor = FillBitFieldFunctor<decltype(portal), RepWordType>;
425 Functor functor{ portal, repWord };
427 const vtkm::Id numWords = portal.template GetNumberOfWords<RepWordType>();
428 DerivedAlgorithm::Schedule(functor, numWords);
// Array fill: prepares `handle` with numValues entries and schedules numValues
// FillArrayHandleFunctor instances, each presumably storing `value` into one entry.
433 template <
typename T,
typename S>
446 auto portal = handle.
PrepareForOutput(numValues, DeviceAdapterTag{}, token);
447 FillArrayHandleFunctor<decltype(portal)> functor{ portal, value };
448 DerivedAlgorithm::Schedule(functor, numValues);
// Array fill (second overload; signature elided from this view). Same scheme: prepare
// numValues entries, then schedule one FillArrayHandleFunctor instance per entry.
453 template <
typename T,
typename S>
467 auto portal = handle.
PrepareForOutput(numValues, DeviceAdapterTag{}, token);
468 FillArrayHandleFunctor<decltype(portal)> functor{ portal, value };
469 DerivedAlgorithm::Schedule(functor, numValues);
// LowerBounds: for each entry of `values`, LowerBoundsKernel searches `input`
// (presumably required to be sorted) and writes the result (presumably the lower-bound
// index) into `output`. One kernel instance is scheduled per output entry.
474 template <
typename T,
class CIn,
class CVal,
class COut>
487 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
489 LowerBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
490 inputPortal, valuesPortal, outputPortal);
492 DerivedAlgorithm::Schedule(kernel, arraySize);
// LowerBounds with a caller-supplied ordering. The search uses binary_compare through
// LowerBoundsComparisonKernel; one instance is scheduled per output entry.
495 template <
typename T,
class CIn,
class CVal,
class COut,
class BinaryCompare>
499 BinaryCompare binary_compare)
509 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
511 LowerBoundsComparisonKernel<decltype(inputPortal),
512 decltype(valuesPortal),
513 decltype(outputPortal),
515 kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
517 DerivedAlgorithm::Schedule(kernel, arraySize);
// In-place LowerBounds: values_output supplies the values to look up and also
// receives the results.
520 template <
class CIn,
class COut>
526 DeviceAdapterAlgorithmGeneral<DerivedAlgorithm, DeviceAdapterTag>::LowerBounds(
527 input, values_output, values_output);
// Decorator implementation used by Reduce. It stores an initial value and a binary
// operator, and for a given portal creates a ReduceKernel bound to both.
533 template <
typename T,
typename BinaryFunctor>
534 class ReduceDecoratorImpl
537 VTKM_CONT ReduceDecoratorImpl() =
default;
540 ReduceDecoratorImpl(
const T& initialValue,
const BinaryFunctor& binaryFunctor)
541 : InitialValue(initialValue)
542 , ReduceOperator(binaryFunctor)
// Presumably invoked by an array-handle decorator with the source array's portal; the
// returned kernel then produces the decorated values — confirm against ReduceKernel.
546 template <
typename Portal>
547 VTKM_CONT ReduceKernel<Portal, T, BinaryFunctor> CreateFunctor(
const Portal& portal)
const
549 return ReduceKernel<Portal, T, BinaryFunctor>(
550 portal, this->InitialValue, this->ReduceOperator);
555 BinaryFunctor ReduceOperator;
// Reduce with the default operator: forwards to the functor overload with vtkm::Add,
// starting from initialValue.
559 template <
typename T,
typename U,
class CIn>
564 return DerivedAlgorithm::Reduce(input, initialValue,
vtkm::Add());
// Reduce with a caller-supplied binary_functor. In the visible lines, an array of
// `length` entries is built over `input` through ReduceDecoratorImpl (seeded with
// initialValue and binary_functor). An array named `reduced` (presumably that one) is
// then inclusively scanned with the same functor into inclusiveScanStorage.
// NOTE(review): the step that extracts the final result from the scan is elided here.
567 template <
typename T,
typename U,
class CIn,
class BinaryFunctor>
570 BinaryFunctor binary_functor)
583 length, ReduceDecoratorImpl<U, BinaryFunctor>(initialValue, binary_functor), input);
587 DerivedAlgorithm::ScanInclusive(reduced, inclusiveScanStorage, binary_functor);
// ReduceByKey: combines values that share a key with binary_functor. keys_output
// receives the de-duplicated keys (Copy + Unique below) and values_output receives the
// combined values selected by the stencil.
// NOTE(review): "share a key" presumably means runs of equal consecutive keys — confirm.
593 template <
typename T,
604 BinaryFunctor binary_functor)
// Zero or one key: nothing to combine, so keys and values are copied through as-is.
613 if (numberOfKeys <= 1)
615 DerivedAlgorithm::Copy(keys, keys_output);
616 DerivedAlgorithm::Copy(values, values_output);
// Classify every key position into `keystate` (presumably marking where runs start/end).
628 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
629 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
630 inputPortal, keyStatePortal);
631 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
// Inclusive scan with ReduceByKeyAdd wrapping binary_functor. The key state presumably
// keeps the combination from crossing key boundaries (a segmented scan) — confirm.
647 DerivedAlgorithm::ScanInclusive(
648 scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
// Keep only the scan results the stencil selects (presumably each run's last entry).
656 DerivedAlgorithm::CopyIf(reducedValues, stencil, values_output, ReduceByKeyUnaryStencilOp());
// Output keys: copy the input keys, then deduplicate them with Unique.
665 DerivedAlgorithm::Copy(keys, keys_output);
666 DerivedAlgorithm::Unique(keys_output);
// ScanExclusive with a caller-supplied binaryFunctor and initialValue. It computes an
// inclusive scan into a temporary, then converts it into `output` with
// InclusiveToExclusiveKernel, which receives initialValue. The conversion presumably
// shifts values by one and puts initialValue first — confirm against the kernel.
// Returns binaryFunctor(initialValue, <value returned by the inclusive scan>).
672 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
675 BinaryFunctor binaryFunctor,
676 const T& initialValue)
688 T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor);
692 auto inputPortal = inclusiveScan.
PrepareForInput(DeviceAdapterTag(), token);
693 auto outputPortal = output.
PrepareForOutput(numValues, DeviceAdapterTag(), token);
695 InclusiveToExclusiveKernel<decltype(inputPortal), decltype(outputPortal), BinaryFunctor>
696 inclusiveToExclusive(inputPortal, outputPortal, binaryFunctor, initialValue);
698 DerivedAlgorithm::Schedule(inclusiveToExclusive, numValues);
// Fold initialValue into the inclusive total so the returned value covers every input.
700 return binaryFunctor(initialValue, result);
// ScanExclusive with default arguments: forwards to the functor/initial-value overload.
// The forwarded arguments are elided from this view.
703 template <
typename T,
class CIn,
class COut>
709 return DerivedAlgorithm::ScanExclusive(
// ScanExtended: computes an inclusive scan into a temporary, then fills `output` with
// numValues + 1 entries through InclusiveToExtendedKernel. The kernel is also given
// binaryFunctor(initialValue, result); presumably the result is the exclusive scan
// followed by that total as the extra final entry — confirm against the kernel.
715 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
718 BinaryFunctor binaryFunctor,
719 const T& initialValue)
732 T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor);
736 auto inputPortal = inclusiveScan.
PrepareForInput(DeviceAdapterTag(), token);
// One more output slot than there are inputs.
737 auto outputPortal = output.
PrepareForOutput(numValues + 1, DeviceAdapterTag(), token);
739 InclusiveToExtendedKernel<decltype(inputPortal), decltype(outputPortal), BinaryFunctor>
740 inclusiveToExtended(inputPortal,
744 binaryFunctor(initialValue, result));
746 DerivedAlgorithm::Schedule(inclusiveToExtended, numValues + 1);
// ScanExtended with default arguments: delegates to the overload that takes an explicit
// functor and initial value. The delegated arguments are not visible in this view.
749 template <
typename T,
class CIn,
class COut>
755 DerivedAlgorithm::ScanExtended(
// ScanExclusiveByKey: an exclusive scan of the values, presumably restarted for each
// key segment and seeded with initialValue. It is implemented as a key-aware shift
// (ShiftCopyAndInit) followed by ScanInclusiveByKey.
761 template <
typename KeyT,
770 const ValueT& initialValue,
771 BinaryFunctor binaryFunctor)
// Special cases for zero and one key (bodies elided from this view).
780 if (numberOfKeys == 0)
784 else if (numberOfKeys == 1)
// Classify every key position into `keystate`.
800 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
801 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
802 inputPortal, keyStatePortal);
803 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
// Shift the values into `temp` using the key state and initialValue (presumably
// inserting initialValue at the start of each segment — confirm against ShiftCopyAndInit).
811 auto keyStatePortal = keystate.
PrepareForInput(DeviceAdapterTag(), token);
812 auto tempPortal = temp.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
814 ShiftCopyAndInit<ValueT,
815 decltype(inputPortal),
816 decltype(keyStatePortal),
817 decltype(tempPortal)>
818 kernel(inputPortal, keyStatePortal, tempPortal, initialValue);
819 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
// An inclusive by-key scan of the shifted values gives the exclusive by-key result.
822 DerivedAlgorithm::ScanInclusiveByKey(keys, temp, output, binaryFunctor);
// ScanExclusiveByKey with default arguments: forwards to the full overload. The
// forwarded arguments are elided from this view.
825 template <
typename KeyT,
typename ValueT,
class KIn,
typename VIn,
typename VOut>
832 DerivedAlgorithm::ScanExclusiveByKey(
// ScanInclusive with the default operator: forwards to the functor overload with
// vtkm::Add and returns whatever that overload returns.
838 template <
typename T,
class CIn,
class COut>
844 return DerivedAlgorithm::ScanInclusive(input, output,
vtkm::Add());
// Two template headers whose declarations and bodies are elided from this view. One
// takes two independent (T1, S1) / (T2, S2) pairs; the other takes a single (T, S) pair.
// NOTE(review): presumably the ArrayHandlesAreSame overloads used by ScanInclusive
// (handles of different types vs. the same type) — confirm.
848 template <
typename T1,
typename S1,
typename T2,
typename S2>
855 template <
typename T,
typename S>
// ScanInclusive with a caller-supplied binary_functor: computes the scan in place in
// `output` and returns its last element.
863 template <
typename T,
class CIn,
class COut,
class BinaryFunctor>
866 BinaryFunctor binary_functor)
// Seed `output` with the input, unless both handles already refer to the same array.
870 if (!ArrayHandlesAreSame(input, output))
872 DerivedAlgorithm::Copy(input, output);
885 using ScanKernelType = ScanKernel<decltype(portal), BinaryFunctor>;
// Phase 1: doubling strides, each pass with offset stride / 2 - 1 and
// numValues / stride kernel instances.
889 for (stride = 2; stride - 1 < numValues; stride *= 2)
891 ScanKernelType kernel(portal, binary_functor, stride, stride / 2 - 1);
892 DerivedAlgorithm::Schedule(kernel, numValues / stride);
// Phase 2: halve back down from the last stride used in phase 1, with offset
// stride - 1, presumably filling in the positions phase 1 left partial.
896 for (stride /= 2; stride > 1; stride /= 2)
898 ScanKernelType kernel(portal, binary_functor, stride, stride - 1);
899 DerivedAlgorithm::Schedule(kernel, numValues / stride);
// The last element of an inclusive scan is the combination of all inputs.
903 return GetExecutionValue(output, numValues - 1);
// ScanInclusiveByKey with the default operator: forwards to the functor overload with
// vtkm::Add.
906 template <
typename KeyT,
typename ValueT,
class KIn,
class VIn,
class VOut>
913 return DerivedAlgorithm::ScanInclusiveByKey(keys, values, values_output,
vtkm::Add());
// ScanInclusiveByKey: inclusive scan of `values` with binary_functor, written to
// values_output. The scan presumably restarts at each new key.
916 template <
typename KeyT,
typename ValueT,
class KIn,
class VIn,
class VOut,
class BinaryFunctor>
920 BinaryFunctor binary_functor)
// Zero or one key: the inclusive scan equals the input, so the values are copied through.
927 if (numberOfKeys <= 1)
929 DerivedAlgorithm::Copy(values, values_output);
// Classify every key position into `keystate`.
941 auto keyStatePortal = keystate.
PrepareForOutput(numberOfKeys, DeviceAdapterTag(), token);
942 ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
943 inputPortal, keyStatePortal);
944 DerivedAlgorithm::Schedule(kernel, numberOfKeys);
// Inclusive scan with ReduceByKeyAdd wrapping binary_functor. The key state presumably
// prevents combining across key boundaries — confirm.
959 DerivedAlgorithm::ScanInclusive(
960 scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
// Every scanned value is kept (no compaction step).
963 DerivedAlgorithm::Copy(reducedValues, values_output);
// Sort with a caller-supplied binary_compare, using the BitonicSortCrossoverKernel and
// BitonicSortMergeKernel passes below.
969 template <
typename T,
class Storage,
class BinaryCompare>
971 BinaryCompare binary_compare)
// numThreads grows while it is below numValues. The loop body and any adjustment after
// it are elided; presumably numThreads tracks a power of two.
981 while (numThreads < numValues)
990 using MergeKernel = BitonicSortMergeKernel<decltype(portal), BinaryCompare>;
991 using CrossoverKernel = BitonicSortCrossoverKernel<decltype(portal), BinaryCompare>;
// For each doubling crossoverSize: one crossover pass, then merge passes with mergeSize
// halving from crossoverSize / 2 down to 1, each merge pass scheduled on numThreads.
993 for (
vtkm::Id crossoverSize = 1; crossoverSize < numValues; crossoverSize *= 2)
995 DerivedAlgorithm::Schedule(CrossoverKernel(portal, binary_compare, crossoverSize),
997 for (
vtkm::Id mergeSize = crossoverSize / 2; mergeSize > 0; mergeSize /= 2)
999 DerivedAlgorithm::Schedule(MergeKernel(portal, binary_compare, mergeSize), numThreads);
// Sort with the default ordering: forwards to the comparison overload with
// DefaultCompareFunctor.
1004 template <
typename T,
class Storage>
1009 DerivedAlgorithm::Sort(values, DefaultCompareFunctor());
// SortByKey: sorts `zipHandle` (presumably a zip of the keys and values arrays) with
// internal::KeyCompare<T, U>, which presumably orders pairs by key only — confirm.
1015 template <
typename T,
typename U,
class StorageT,
class StorageU>
1025 DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U>());
// SortByKey with a caller-supplied key ordering. binary_compare is wrapped in
// internal::KeyCompare and used to sort `zipHandle` (presumably zipped keys/values).
1028 template <
typename T,
typename U,
class StorageT,
class StorageU,
class BinaryCompare>
1031 BinaryCompare binary_compare)
1040 DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare));
// Binary transform: BinaryTransformKernel reads input1 and input2 and writes `output`
// (prepared with numValues entries). Presumably output[i] = binaryFunctor(input1[i],
// input2[i]). One kernel instance is scheduled per value.
1043 template <
typename T,
1049 typename BinaryFunctor>
1053 BinaryFunctor binaryFunctor)
1065 auto input1Portal = input1.
PrepareForInput(DeviceAdapterTag(), token);
1066 auto input2Portal = input2.
PrepareForInput(DeviceAdapterTag(), token);
1067 auto outputPortal = output.
PrepareForOutput(numValues, DeviceAdapterTag(), token);
1069 BinaryTransformKernel<decltype(input1Portal),
1070 decltype(input2Portal),
1071 decltype(outputPortal),
1073 binaryKernel(input1Portal, input2Portal, outputPortal, binaryFunctor);
1074 DerivedAlgorithm::Schedule(binaryKernel, numValues);
// Template header only. The declaration and body are elided from this view; this is
// presumably the default-comparison Unique overload — confirm.
1080 template <
typename T,
class Storage>
// Unique with a caller-supplied binary_compare. It marks the entries to keep in
// stencilArray, compacts them into outputArray with CopyIf, and copies the result back
// into `values`.
1088 template <
typename T,
class Storage,
class BinaryCompare>
1090 BinaryCompare binary_compare)
// Wrap binary_compare (presumably to give the kernel a bool-returning operator).
1097 using WrappedBOpType = internal::WrappedBinaryOperator<bool, BinaryCompare>;
1098 WrappedBOpType wrappedCompare(binary_compare);
1102 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1103 auto stencilPortal = stencilArray.
PrepareForOutput(inputSize, DeviceAdapterTag(), token);
1104 ClassifyUniqueComparisonKernel<decltype(valuesPortal),
1105 decltype(stencilPortal),
1107 classifyKernel(valuesPortal, stencilPortal, wrappedCompare);
// One classification per input entry; presumably each entry is compared with its
// predecessor to decide whether it is kept — confirm.
1109 DerivedAlgorithm::Schedule(classifyKernel, inputSize);
1114 DerivedAlgorithm::CopyIf(values, stencilArray, outputArray);
// Copy the compacted result back so `values` holds only the kept entries.
1117 DerivedAlgorithm::Copy(outputArray, values);
// UpperBounds: for each entry of `values`, UpperBoundsKernel searches `input`
// (presumably required to be sorted) and writes the result (presumably the upper-bound
// index) into `output`. One kernel instance is scheduled per output entry.
1122 template <
typename T,
class CIn,
class CVal,
class COut>
1134 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1135 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
1137 UpperBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
1138 inputPortal, valuesPortal, outputPortal);
1139 DerivedAlgorithm::Schedule(kernel, arraySize);
// UpperBounds with a caller-supplied ordering. The search uses binary_compare through
// UpperBoundsKernelComparisonKernel; one instance is scheduled per output entry.
1142 template <
typename T,
class CIn,
class CVal,
class COut,
class BinaryCompare>
1146 BinaryCompare binary_compare)
1155 auto valuesPortal = values.
PrepareForInput(DeviceAdapterTag(), token);
1156 auto outputPortal = output.
PrepareForOutput(arraySize, DeviceAdapterTag(), token);
1158 UpperBoundsKernelComparisonKernel<decltype(inputPortal),
1159 decltype(valuesPortal),
1160 decltype(outputPortal),
1162 kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
1164 DerivedAlgorithm::Schedule(kernel, arraySize);
// In-place UpperBounds: values_output supplies the values to look up and also receives
// the results.
1167 template <
class CIn,
class COut>
1173 DeviceAdapterAlgorithmGeneral<DerivedAlgorithm, DeviceAdapterTag>::UpperBounds(
1174 input, values_output, values_output);
// Task factory parameterized on a device tag. Each visible MakeTask overload wraps a
// worklet and its invocation in a vtkm::exec::internal::TaskSingular and forwards
// globalIndexOffset.
// NOTE(review): the parameter that distinguishes the two overloads, and the default (if
// any) of globalIndexOffset, are elided from this view.
1190 template <
typename DeviceTag>
1191 class DeviceTaskTypes
1194 template <
typename WorkletType,
typename InvocationType>
1195 static vtkm::exec::internal::TaskSingular<WorkletType, InvocationType>
MakeTask(
1196 WorkletType& worklet,
1197 InvocationType& invocation,
1201 using Task = vtkm::exec::internal::TaskSingular<WorkletType, InvocationType>;
1202 return Task(worklet, invocation, globalIndexOffset);
1205 template <
typename WorkletType,
typename InvocationType>
1206 static vtkm::exec::internal::TaskSingular<WorkletType, InvocationType>
MakeTask(
1207 WorkletType& worklet,
1208 InvocationType& invocation,
1212 using Task = vtkm::exec::internal::TaskSingular<WorkletType, InvocationType>;
1213 return Task(worklet, invocation, globalIndexOffset);
1219 #endif //vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h