10 #ifndef vtk_m_cont_tbb_internal_FunctorsTBB_h
11 #define vtk_m_cont_tbb_internal_FunctorsTBB_h
23 #include <type_traits>
25 VTKM_THIRDPARTY_PRE_INCLUDE
27 #if defined(VTKM_MSVC)
// NOTE(review): name mismatch. push_macro below saves "__TBB_NO_IMPLICITLINKAGE"
// (no underscore between IMPLICIT and LINKAGE), but the #define on the next
// line is "__TBB_NO_IMPLICIT_LINKAGE". The pop_macro after the TBB includes
// uses the same misspelled name, so this push/pop pair saves and restores a
// name that is never defined here, and __TBB_NO_IMPLICIT_LINKAGE appears to
// remain defined after the pop. The push_macro here and the pop_macro after
// the TBB includes should both name __TBB_NO_IMPLICIT_LINKAGE. They must be
// changed together: fixing only one side leaves an unmatched push/pop.
// Defining __TBB_NO_IMPLICIT_LINKAGE is presumably meant to stop TBB's headers
// from requesting automatic library linkage on MSVC; confirm against the TBB
// headers in use.
#pragma push_macro("__TBB_NO_IMPLICITLINKAGE")
#define __TBB_NO_IMPLICIT_LINKAGE 1
35 #endif // defined(VTKM_MSVC)
43 #include <tbb/blocked_range.h>
44 #include <tbb/blocked_range3d.h>
45 #include <tbb/parallel_for.h>
46 #include <tbb/parallel_reduce.h>
47 #include <tbb/parallel_scan.h>
48 #include <tbb/parallel_sort.h>
49 #include <tbb/partitioner.h>
50 #include <tbb/tick_count.h>
52 #if defined(VTKM_MSVC)
53 #pragma pop_macro("__TBB_NO_IMPLICITLINKAGE")
56 VTKM_THIRDPARTY_POST_INCLUDE
67 template <
typename ResultType,
typename Function>
68 using WrappedBinaryOperator = vtkm::cont::internal::WrappedBinaryOperator<ResultType, Function>;
73 static constexpr
vtkm::Id TBB_GRAIN_SIZE = 1024;
75 template <
typename InputPortalType,
typename OutputPortalType>
84 const OutputPortalType& outPortal,
99 template <
typename InIter,
typename OutIter>
100 void DoCopy(InIter src, InIter srcEnd, OutIter dst, std::false_type)
const
102 using InputType =
typename InputPortalType::ValueType;
103 using OutputType =
typename OutputPortalType::ValueType;
104 while (src != srcEnd)
110 *dst =
static_cast<OutputType
>(
static_cast<InputType
>(*src));
117 template <
typename InIter,
typename OutIter>
118 void DoCopy(InIter src, InIter srcEnd, OutIter dst, std::true_type)
const
120 std::copy(src, srcEnd, dst);
125 void operator()(const ::tbb::blocked_range<vtkm::Id>& range)
const
135 using InputType =
typename InputPortalType::ValueType;
136 using OutputType =
typename OutputPortalType::ValueType;
138 this->
DoCopy(inIter + this->InputOffset + range.begin(),
139 inIter + this->InputOffset + range.end(),
140 outIter + this->OutputOffset + range.begin(),
141 std::is_same<InputType, OutputType>());
145 template <
typename InputPortalType,
typename OutputPortalType>
147 const OutputPortalType& outPortal,
153 Kernel kernel(inPortal, outPortal, inOffset, outOffset);
154 ::tbb::blocked_range<vtkm::Id> range(0, numValues, TBB_GRAIN_SIZE);
155 ::tbb::parallel_for(range, kernel);
158 template <
typename InputPortalType,
159 typename StencilPortalType,
160 typename OutputPortalType,
161 typename UnaryPredicateType>
201 (this->OutputEnd - this->OutputBegin) <= (this->InputEnd - this->InputBegin));
216 const StencilPortalType& stencilPortal,
217 const OutputPortalType& outputPortal,
218 UnaryPredicateType unaryPredicate)
254 this->Ranges.
InputEnd = range.end();
261 InputIteratorsType inputIters(this->InputPortal);
262 StencilIteratorsType stencilIters(this->StencilPortal);
263 OutputIteratorsType outputIters(this->OutputPortal);
265 using InputIteratorType =
typename InputIteratorsType::IteratorType;
266 using StencilIteratorType =
typename StencilIteratorsType::IteratorType;
267 using OutputIteratorType =
typename OutputIteratorsType::IteratorType;
269 InputIteratorType inIter = inputIters.GetBegin();
270 StencilIteratorType stencilIter = stencilIters.GetBegin();
271 OutputIteratorType outIter = outputIters.GetBegin();
274 const vtkm::Id readEnd = range.end();
283 writePos = range.begin();
296 UnaryPredicateType predicate(this->UnaryPredicate);
297 for (; readPos < readEnd; ++readPos)
299 if (predicate(stencilIter[readPos]))
301 outIter[writePos] = inIter[readPos];
314 using OutputIteratorType =
typename OutputIteratorsType::IteratorType;
316 OutputIteratorsType outputIters(this->OutputPortal);
317 OutputIteratorType outIter = outputIters.GetBegin();
329 if (srcBegin != dstBegin && srcBegin != srcEnd)
333 std::copy(outIter + srcBegin, outIter + srcEnd, outIter + dstBegin);
337 this->Ranges.
OutputEnd += srcEnd - srcBegin;
343 template <
typename InputPortalType,
344 typename StencilPortalType,
345 typename OutputPortalType,
346 typename UnaryPredicateType>
348 StencilPortalType stencilPortal,
349 OutputPortalType outputPortal,
350 UnaryPredicateType unaryPredicate)
352 const vtkm::Id inputLength = inputPortal.GetNumberOfValues();
353 VTKM_ASSERT(inputLength == stencilPortal.GetNumberOfValues());
355 if (inputLength == 0)
361 inputPortal, stencilPortal, outputPortal, unaryPredicate);
362 ::tbb::blocked_range<vtkm::Id> range(0, inputLength, TBB_GRAIN_SIZE);
364 ::tbb::parallel_reduce(range, body);
368 body.
Ranges.OutputBegin == 0 && body.
Ranges.OutputEnd <= inputLength);
370 return body.
Ranges.OutputEnd;
373 template <
class InputPortalType,
class T,
class BinaryOperationType>
385 BinaryOperationType binaryOperation)
409 InputIteratorsType inputIterators(this->InputPortal);
412 typename InputIteratorsType::IteratorType inIter =
413 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
418 for (
vtkm::Id index = range.begin() + 2; index != range.end(); ++index, ++inIter)
424 if (range.begin() == 0)
441 this->FirstCall =
false;
454 template <
class InputPortalType,
typename T,
class BinaryOperationType>
455 VTKM_CONT static auto ReducePortals(InputPortalType inputPortal,
457 BinaryOperationType binaryOperation)
458 -> decltype(binaryOperation(initialValue, inputPortal.Get(0)))
460 using ResultType = decltype(binaryOperation(initialValue, inputPortal.Get(0)));
461 using WrappedBinaryOp = internal::WrappedBinaryOperator<ResultType, BinaryOperationType>;
463 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
464 ReduceBody<InputPortalType, ResultType, WrappedBinaryOp> body(
465 inputPortal, initialValue, wrappedBinaryOp);
466 vtkm::Id arrayLength = inputPortal.GetNumberOfValues();
470 ::tbb::blocked_range<vtkm::Id> range(0, arrayLength, TBB_GRAIN_SIZE);
471 ::tbb::parallel_reduce(range, body);
474 else if (arrayLength == 1)
477 return binaryOperation(initialValue, inputPortal.Get(0));
482 return static_cast<ResultType
>(initialValue);
490 template <
typename KeysInPortalType,
491 typename ValuesInPortalType,
492 typename KeysOutPortalType,
493 typename ValuesOutPortalType,
494 class BinaryOperationType>
497 using KeyType =
typename KeysInPortalType::ValueType;
498 using ValueType =
typename ValuesInPortalType::ValueType;
534 (this->OutputEnd - this->OutputBegin) <= (this->InputEnd - this->InputBegin));
547 #ifdef VTKM_DEBUG_TBB_RBK
554 const ValuesInPortalType& valuesInPortal,
555 const KeysOutPortalType& keysOutPortal,
556 const ValuesOutPortalType& valuesOutPortal,
557 BinaryOperationType binaryOperation)
563 #ifdef VTKM_DEBUG_TBB_RBK
577 #ifdef VTKM_DEBUG_TBB_RBK
588 #ifdef VTKM_DEBUG_TBB_RBK
589 ::tbb::tick_count startTime = ::tbb::tick_count::now();
590 #endif // VTKM_DEBUG_TBB_RBK
606 this->Ranges.
InputEnd = range.end();
614 KeysInIteratorsType keysInIters(this->KeysInPortal);
615 ValuesInIteratorsType valuesInIters(this->ValuesInPortal);
616 KeysOutIteratorsType keysOutIters(this->KeysOutPortal);
617 ValuesOutIteratorsType valuesOutIters(this->ValuesOutPortal);
619 using KeysInIteratorType =
typename KeysInIteratorsType::IteratorType;
620 using ValuesInIteratorType =
typename ValuesInIteratorsType::IteratorType;
621 using KeysOutIteratorType =
typename KeysOutIteratorsType::IteratorType;
622 using ValuesOutIteratorType =
typename ValuesOutIteratorsType::IteratorType;
624 KeysInIteratorType keysIn = keysInIters.GetBegin();
625 ValuesInIteratorType valuesIn = valuesInIters.GetBegin();
626 KeysOutIteratorType keysOut = keysOutIters.GetBegin();
627 ValuesOutIteratorType valuesOut = valuesOutIters.GetBegin();
630 const vtkm::Id readEnd = range.end();
639 writePos = range.begin();
653 BinaryOperationType functor(this->BinaryOperation);
654 KeyType currentKey = keysIn[readPos];
655 ValueType currentValue = valuesIn[readPos];
661 if (!firstRun && keysOut[writePos - 1] == currentKey)
667 currentValue = functor(valuesOut[writePos], currentValue);
671 if (readPos >= readEnd)
673 keysOut[writePos] = currentKey;
674 valuesOut[writePos] = currentValue;
683 while (readPos < readEnd && currentKey == keysIn[readPos])
685 currentValue = functor(currentValue, valuesIn[readPos]);
690 keysOut[writePos] = currentKey;
691 valuesOut[writePos] = currentValue;
694 if (readPos < readEnd)
696 currentKey = keysIn[readPos];
697 currentValue = valuesIn[readPos];
707 #ifdef VTKM_DEBUG_TBB_RBK
708 ::tbb::tick_count endTime = ::tbb::tick_count::now();
709 double time = (endTime - startTime).seconds();
710 this->ReduceTime += time;
711 std::ostringstream out;
712 out <<
"Reduced " << range.size() <<
" key/value pairs in " << time <<
"s. "
715 std::cerr << out.str();
725 using KeysIteratorType =
typename KeysIteratorsType::IteratorType;
726 using ValuesIteratorType =
typename ValuesIteratorsType::IteratorType;
728 #ifdef VTKM_DEBUG_TBB_RBK
729 ::tbb::tick_count startTime = ::tbb::tick_count::now();
737 KeysIteratorsType keysIters(this->KeysOutPortal);
738 ValuesIteratorsType valuesIters(this->ValuesOutPortal);
739 KeysIteratorType keys = keysIters.GetBegin();
740 ValuesIteratorType values = valuesIters.GetBegin();
749 if (keys[srcBegin] == keys[lastDstIdx])
751 values[lastDstIdx] = this->
BinaryOperation(values[lastDstIdx], values[srcBegin]);
756 if (srcBegin != dstBegin && srcBegin != srcEnd)
760 std::copy(keys + srcBegin, keys + srcEnd, keys + dstBegin);
761 std::copy(values + srcBegin, values + srcEnd, values + dstBegin);
765 this->Ranges.
OutputEnd += srcEnd - srcBegin;
768 #ifdef VTKM_DEBUG_TBB_RBK
769 ::tbb::tick_count endTime = ::tbb::tick_count::now();
770 double time = (endTime - startTime).seconds();
771 this->JoinTime += rhs.JoinTime + time;
772 std::ostringstream out;
773 out <<
"Joined " << (srcEnd - srcBegin) <<
" rhs values into body in " << time <<
"s. "
774 <<
"InRange: " << this->Ranges.
InputBegin <<
" " << this->Ranges.InputEnd <<
" "
775 <<
"OutRange: " << this->Ranges.OutputBegin <<
" " << this->Ranges.OutputEnd <<
"\n";
776 std::cerr << out.str();
782 template <
typename KeysInPortalType,
783 typename ValuesInPortalType,
784 typename KeysOutPortalType,
785 typename ValuesOutPortalType,
786 typename BinaryOperationType>
788 ValuesInPortalType valuesInPortal,
789 KeysOutPortalType keysOutPortal,
790 ValuesOutPortalType valuesOutPortal,
791 BinaryOperationType binaryOperation)
793 const vtkm::Id inputLength = keysInPortal.GetNumberOfValues();
794 VTKM_ASSERT(inputLength == valuesInPortal.GetNumberOfValues());
796 if (inputLength == 0)
801 using ValueType =
typename ValuesInPortalType::ValueType;
802 using WrappedBinaryOp = internal::WrappedBinaryOperator<ValueType, BinaryOperationType>;
803 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
810 body(keysInPortal, valuesInPortal, keysOutPortal, valuesOutPortal, wrappedBinaryOp);
811 ::tbb::blocked_range<vtkm::Id> range(0, inputLength, TBB_GRAIN_SIZE);
813 #ifdef VTKM_DEBUG_TBB_RBK
814 std::cerr <<
"\n\nTBB ReduceByKey:\n";
817 ::tbb::parallel_reduce(range, body);
819 #ifdef VTKM_DEBUG_TBB_RBK
820 std::cerr <<
"Total reduce time: " << body.ReduceTime <<
"s\n";
821 std::cerr <<
"Total join time: " << body.JoinTime <<
"s\n";
822 std::cerr <<
"\nend\n";
825 body.Ranges.AssertSane();
826 VTKM_ASSERT(body.Ranges.InputBegin == 0 && body.Ranges.InputEnd == inputLength &&
827 body.Ranges.OutputBegin == 0 && body.Ranges.OutputEnd <= inputLength);
829 return body.Ranges.OutputEnd;
832 #ifdef VTKM_DEBUG_TBB_RBK
833 #undef VTKM_DEBUG_TBB_RBK
836 template <
class InputPortalType,
class OutputPortalType,
class BinaryOperationType>
839 using ValueType =
typename std::remove_reference<typename OutputPortalType::ValueType>::type;
848 const OutputPortalType& outputPortal,
849 BinaryOperationType binaryOperation)
870 void operator()(const ::tbb::blocked_range<vtkm::Id>& range, ::tbb::pre_scan_tag)
873 InputIteratorsType inputIterators(this->InputPortal);
876 typename InputIteratorsType::IteratorType inIter =
877 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
879 this->FirstCall =
false;
880 for (
vtkm::Id index = range.begin() + 1; index != range.end(); ++index, ++inIter)
889 void operator()(const ::tbb::blocked_range<vtkm::Id>& range, ::tbb::final_scan_tag)
894 InputIteratorsType inputIterators(this->InputPortal);
895 OutputIteratorsType outputIterators(this->OutputPortal);
898 typename InputIteratorsType::IteratorType inIter =
899 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
900 typename OutputIteratorsType::IteratorType outIter =
901 outputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
903 this->FirstCall =
false;
905 for (
vtkm::Id index = range.begin() + 1; index != range.end(); ++index, ++inIter, ++outIter)
924 template <
class InputPortalType,
class OutputPortalType,
class BinaryOperationType>
927 using ValueType =
typename std::remove_reference<typename OutputPortalType::ValueType>::type;
937 const OutputPortalType& outputPortal,
938 BinaryOperationType binaryOperation,
960 void operator()(const ::tbb::blocked_range<vtkm::Id>& range, ::tbb::pre_scan_tag)
963 InputIteratorsType inputIterators(this->InputPortal);
966 typename InputIteratorsType::IteratorType iter =
967 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
971 if (!(this->FirstCall && range.begin() > 0))
975 for (
vtkm::Id index = range.begin() + 1; index != range.end(); ++index, ++iter)
980 this->FirstCall =
false;
985 void operator()(const ::tbb::blocked_range<vtkm::Id>& range, ::tbb::final_scan_tag)
990 InputIteratorsType inputIterators(this->InputPortal);
991 OutputIteratorsType outputIterators(this->OutputPortal);
994 typename InputIteratorsType::IteratorType inIter =
995 inputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
996 typename OutputIteratorsType::IteratorType outIter =
997 outputIterators.GetBegin() +
static_cast<std::ptrdiff_t
>(range.begin());
1000 for (
vtkm::Id index = range.begin(); index != range.end(); ++index, ++inIter, ++outIter)
1009 this->FirstCall =
false;
1020 if (!left.
FirstCall && !this->FirstCall)
1032 template <
class InputPortalType,
class OutputPortalType,
class BinaryOperationType>
1033 VTKM_CONT static typename std::remove_reference<typename OutputPortalType::ValueType>::type
1034 ScanInclusivePortals(InputPortalType inputPortal,
1035 OutputPortalType outputPortal,
1036 BinaryOperationType binaryOperation)
1038 using ValueType =
typename std::remove_reference<typename OutputPortalType::ValueType>::type;
1040 using WrappedBinaryOp = internal::WrappedBinaryOperator<ValueType, BinaryOperationType>;
1042 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
1043 ScanInclusiveBody<InputPortalType, OutputPortalType, WrappedBinaryOp> body(
1044 inputPortal, outputPortal, wrappedBinaryOp);
1045 vtkm::Id arrayLength = inputPortal.GetNumberOfValues();
1047 ::tbb::blocked_range<vtkm::Id> range(0, arrayLength, TBB_GRAIN_SIZE);
1048 ::tbb::parallel_scan(range, body);
1053 template <
class InputPortalType,
class OutputPortalType,
class BinaryOperationType>
1054 VTKM_CONT static typename std::remove_reference<typename OutputPortalType::ValueType>::type
1055 ScanExclusivePortals(
1056 InputPortalType inputPortal,
1057 OutputPortalType outputPortal,
1058 BinaryOperationType binaryOperation,
1059 typename std::remove_reference<typename OutputPortalType::ValueType>::type initialValue)
1061 using ValueType =
typename std::remove_reference<typename OutputPortalType::ValueType>::type;
1063 using WrappedBinaryOp = internal::WrappedBinaryOperator<ValueType, BinaryOperationType>;
1065 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
1066 ScanExclusiveBody<InputPortalType, OutputPortalType, WrappedBinaryOp> body(
1067 inputPortal, outputPortal, wrappedBinaryOp, initialValue);
1068 vtkm::Id arrayLength = inputPortal.GetNumberOfValues();
1070 ::tbb::blocked_range<vtkm::Id> range(0, arrayLength, TBB_GRAIN_SIZE);
1071 ::tbb::parallel_scan(range, body);
1078 template <
typename InputPortalType,
typename IndexPortalType,
typename OutputPortalType>
1083 IndexPortalType indexPortal,
1084 OutputPortalType outputPortal)
1092 void operator()(const ::tbb::blocked_range<vtkm::Id>& range)
const
1102 VTKM_VECTORIZATION_PRE_LOOP
1103 for (
vtkm::Id i = range.begin(); i < range.end(); i++)
1105 VTKM_VECTORIZATION_IN_LOOP
1115 this->
ErrorMessage.RaiseError(
"Unexpected error in execution environment.");
1127 template <
typename InputPortalType,
typename IndexPortalType,
typename OutputPortalType>
1128 VTKM_CONT static void ScatterPortal(InputPortalType inputPortal,
1129 IndexPortalType indexPortal,
1130 OutputPortalType outputPortal)
1132 const vtkm::Id size = inputPortal.GetNumberOfValues();
1133 VTKM_ASSERT(size == indexPortal.GetNumberOfValues());
1136 inputPortal, indexPortal, outputPortal);
1138 ::tbb::blocked_range<vtkm::Id> range(0, size, TBB_GRAIN_SIZE);
1139 ::tbb::parallel_for(range, scatter);
1142 template <
typename PortalType,
typename BinaryOperationType>
1182 this->OutputEnd <= this->
InputEnd);
1184 (this->OutputEnd - this->OutputBegin) <= (this->InputEnd - this->InputBegin));
1197 UniqueBody(
const PortalType& portal, BinaryOperationType binaryOperation)
1229 this->Ranges.
InputEnd = range.end();
1233 using IteratorType =
typename IteratorsType::IteratorType;
1235 IteratorsType iters(this->Portal);
1236 IteratorType data = iters.GetBegin();
1239 const vtkm::Id readEnd = range.end();
1248 writePos = range.begin();
1262 BinaryOperationType functor(this->BinaryOperation);
1270 if (!firstRun && functor(data[writePos - 1], current))
1278 current = data[writePos];
1282 if (readPos >= readEnd)
1284 data[writePos] = current;
1294 while (readPos < readEnd && functor(current, data[readPos]))
1301 data[writePos] = current;
1305 if (readPos < readEnd)
1307 current = data[readPos];
1324 using IteratorType =
typename IteratorsType::IteratorType;
1332 IteratorsType iters(this->Portal);
1333 IteratorType data = iters.GetBegin();
1334 BinaryOperationType functor(this->BinaryOperation);
1343 if (functor(data[srcBegin], data[lastDstIdx]))
1349 if (srcBegin != dstBegin && srcBegin != srcEnd)
1353 std::copy(data + srcBegin, data + srcEnd, data + dstBegin);
1357 this->Ranges.
OutputEnd += srcEnd - srcBegin;
1363 template <
typename PortalType,
typename BinaryOperationType>
1366 const vtkm::Id inputLength = portal.GetNumberOfValues();
1367 if (inputLength == 0)
1372 using WrappedBinaryOp = internal::WrappedBinaryOperator<bool, BinaryOperationType>;
1373 WrappedBinaryOp wrappedBinaryOp(binaryOperation);
1376 ::tbb::blocked_range<vtkm::Id> range(0, inputLength, TBB_GRAIN_SIZE);
1378 ::tbb::parallel_reduce(range, body);
1380 body.
Ranges.AssertSane();
1382 body.
Ranges.OutputBegin == 0 && body.
Ranges.OutputEnd <= inputLength);
1384 return body.
Ranges.OutputEnd;
1389 #endif //vtk_m_cont_tbb_internal_FunctorsTBB_h