81 #ifndef vtk_m_cont_internal_ParallelRadixSort_h
82 #define vtk_m_cont_internal_ParallelRadixSort_h
98 VTKM_THIRDPARTY_PRE_INCLUDE
114 inline size_t GetMaxThreads(
size_t num_bytes,
size_t available_cores)
116 const double CORES_PER_BYTE =
117 double(available_cores - 1) / double(BYTES_FOR_MAX_PARALLELISM - MIN_BYTES_FOR_PARALLEL);
118 const double Y_INTERCEPT = 1.0 - CORES_PER_BYTE * MIN_BYTES_FOR_PARALLEL;
120 const size_t num_cores = (size_t)(CORES_PER_BYTE *
double(num_bytes) + Y_INTERCEPT);
125 if (num_cores > available_cores)
127 return available_cores;
136 const size_t kOutBufferSize = 32;
139 template <
typename PlainType,
140 typename UnsignedType,
141 typename CompareType,
142 typename ValueManager,
144 struct ParallelRadixCompareInternal
146 inline static void reverse(UnsignedType& t) { (void)t; }
150 template <
typename PlainType,
typename Un
signedType,
typename ValueManager,
unsigned int Base>
151 struct ParallelRadixCompareInternal<PlainType,
153 std::greater<PlainType>,
157 inline static void reverse(UnsignedType& t) { t = ((1 << Base) - 1) - t; }
161 template <
typename PlainType,
162 typename CompareType,
163 typename UnsignedType,
165 typename ValueManager,
166 typename ThreaderType,
168 class ParallelRadixSortInternal
171 using CompareInternal =
172 ParallelRadixCompareInternal<PlainType, UnsignedType, CompareType, ValueManager, Base>;
174 ParallelRadixSortInternal();
175 ~ParallelRadixSortInternal();
177 void Init(PlainType* data,
size_t num_elems,
const ThreaderType& threader);
179 PlainType* Sort(PlainType* data, ValueManager* value_manager);
181 static void InitAndSort(PlainType* data,
183 const ThreaderType& threader,
184 ValueManager* value_manager);
187 CompareInternal compare_internal_;
193 UnsignedType*** out_buf_;
196 size_t *pos_bgn_, *pos_end_;
197 ValueManager* value_manager_;
198 ThreaderType threader_;
202 UnsignedType* SortInternal(UnsignedType* data, ValueManager* value_manager);
205 void ComputeRanges();
209 void ComputeHistogram(
unsigned int b, UnsignedType* src);
213 void Scatter(
unsigned int b, UnsignedType* src, UnsignedType* dst);
216 template <
typename PlainType,
217 typename CompareType,
218 typename UnsignedType,
220 typename ValueManager,
221 typename ThreaderType,
223 ParallelRadixSortInternal<PlainType,
229 Base>::ParallelRadixSortInternal()
239 assert(
sizeof(PlainType) ==
sizeof(UnsignedType));
242 template <
typename PlainType,
243 typename CompareType,
244 typename UnsignedType,
246 typename ValueManager,
247 typename ThreaderType,
249 ParallelRadixSortInternal<PlainType,
255 Base>::~ParallelRadixSortInternal()
260 template <
typename PlainType,
261 typename CompareType,
262 typename UnsignedType,
264 typename ValueManager,
265 typename ThreaderType,
267 void ParallelRadixSortInternal<PlainType,
278 for (
size_t i = 0; i < num_threads_; ++i)
283 for (
size_t i = 0; i < num_threads_; ++i)
285 for (
size_t j = 0; j < 1 << Base; ++j)
287 delete[] out_buf_[i][j];
289 delete[] out_buf_n_[i];
290 delete[] out_buf_[i];
299 pos_bgn_ = pos_end_ = NULL;
305 template <
typename PlainType,
306 typename CompareType,
307 typename UnsignedType,
309 typename ValueManager,
310 typename ThreaderType,
312 void ParallelRadixSortInternal<PlainType,
318 Base>::Init(PlainType* data,
320 const ThreaderType& threader)
325 threader_ = threader;
327 num_elems_ = num_elems;
330 utility::GetMaxThreads(num_elems_ *
sizeof(PlainType), threader_.GetAvailableCores());
332 tmp_ =
new UnsignedType[num_elems_];
333 histo_ =
new size_t*[num_threads_];
334 for (
size_t i = 0; i < num_threads_; ++i)
336 histo_[i] =
new size_t[1 << Base];
339 out_buf_ =
new UnsignedType**[num_threads_];
340 out_buf_n_ =
new size_t*[num_threads_];
341 for (
size_t i = 0; i < num_threads_; ++i)
343 out_buf_[i] =
new UnsignedType*[1 << Base];
344 out_buf_n_[i] =
new size_t[1 << Base];
345 for (
size_t j = 0; j < 1 << Base; ++j)
347 out_buf_[i][j] =
new UnsignedType[kOutBufferSize];
351 pos_bgn_ =
new size_t[num_threads_];
352 pos_end_ =
new size_t[num_threads_];
355 template <
typename PlainType,
356 typename CompareType,
357 typename UnsignedType,
359 typename ValueManager,
360 typename ThreaderType,
362 PlainType* ParallelRadixSortInternal<PlainType,
368 Base>::Sort(PlainType* data, ValueManager* value_manager)
370 UnsignedType* src =
reinterpret_cast<UnsignedType*
>(data);
371 UnsignedType* res = SortInternal(src, value_manager);
372 return reinterpret_cast<PlainType*
>(res);
375 template <
typename PlainType,
376 typename CompareType,
377 typename UnsignedType,
379 typename ValueManager,
380 typename ThreaderType,
382 void ParallelRadixSortInternal<PlainType,
388 Base>::InitAndSort(PlainType* data,
390 const ThreaderType& threader,
391 ValueManager* value_manager)
393 ParallelRadixSortInternal prs;
394 prs.Init(data, num_elems, threader);
395 const PlainType* res = prs.Sort(data, value_manager);
398 for (
size_t i = 0; i < num_elems; ++i)
403 template <
typename PlainType,
404 typename CompareType,
405 typename UnsignedType,
407 typename ValueManager,
408 typename ThreaderType,
410 UnsignedType* ParallelRadixSortInternal<PlainType,
416 Base>::SortInternal(UnsignedType* data,
417 ValueManager* value_manager)
420 value_manager_ = value_manager;
426 const size_t bits = CHAR_BIT *
sizeof(UnsignedType);
427 UnsignedType *src = data, *dst = tmp_;
428 for (
unsigned int b = 0; b < bits; b += Base)
430 ComputeHistogram(b, src);
431 Scatter(b, src, dst);
434 value_manager->Next();
440 template <
typename PlainType,
441 typename CompareType,
442 typename UnsignedType,
444 typename ValueManager,
445 typename ThreaderType,
447 void ParallelRadixSortInternal<PlainType,
453 Base>::ComputeRanges()
456 for (
size_t i = 0; i < num_threads_ - 1; ++i)
458 const size_t t = (num_elems_ - pos_bgn_[i]) / (num_threads_ - i);
459 pos_bgn_[i + 1] = pos_end_[i] = pos_bgn_[i] + t;
461 pos_end_[num_threads_ - 1] = num_elems_;
464 template <
typename PlainType,
465 typename UnsignedType,
469 typename ThreaderType>
472 RunTask(
size_t binary_tree_height,
473 size_t binary_tree_position,
477 const ThreaderType& threader)
478 : binary_tree_height_(binary_tree_height)
479 , binary_tree_position_(binary_tree_position)
481 , num_elems_(num_elems)
482 , num_threads_(num_threads)
483 , threader_(threader)
487 template <
typename ThreaderData =
void*>
488 void operator()(ThreaderData tData =
nullptr)
const
490 size_t num_nodes_at_current_height = (size_t)pow(2, (
double)binary_tree_height_);
491 if (num_threads_ <= num_nodes_at_current_height)
493 const size_t my_id = binary_tree_position_ - num_nodes_at_current_height;
494 if (my_id < num_threads_)
501 RunTask left(binary_tree_height_ + 1,
502 2 * binary_tree_position_,
507 RunTask right(binary_tree_height_ + 1,
508 2 * binary_tree_position_ + 1,
513 threader_.RunChildTasks(tData, left, right);
517 size_t binary_tree_height_;
518 size_t binary_tree_position_;
522 ThreaderType threader_;
525 template <
typename PlainType,
526 typename CompareType,
527 typename UnsignedType,
529 typename ValueManager,
530 typename ThreaderType,
532 void ParallelRadixSortInternal<PlainType,
538 Base>::ComputeHistogram(
unsigned int b, UnsignedType* src)
542 auto lambda = [=](
const size_t my_id) {
543 const size_t my_bgn = pos_bgn_[my_id];
544 const size_t my_end = pos_end_[my_id];
545 size_t* my_histo = histo_[my_id];
547 memset(my_histo, 0,
sizeof(
size_t) * (1 << Base));
548 for (
size_t i = my_bgn; i < my_end; ++i)
550 const UnsignedType s = Encoder::encode(src[i]);
551 UnsignedType t = (s >> b) & ((1 << Base) - 1);
552 compare_internal_.reverse(t);
558 RunTask<PlainType, UnsignedType, Encoder, Base, std::function<void(
size_t)>, ThreaderType>;
560 RunTaskType root(0, 1, lambda, num_elems_, num_threads_, threader_);
561 this->threader_.RunParentTask(root);
565 for (
size_t i = 0; i < 1 << Base; ++i)
567 for (
size_t j = 0; j < num_threads_; ++j)
569 const size_t t = s + histo_[j][i];
576 template <
typename PlainType,
577 typename CompareType,
578 typename UnsignedType,
580 typename ValueManager,
581 typename ThreaderType,
583 void ParallelRadixSortInternal<PlainType,
589 Base>::Scatter(
unsigned int b, UnsignedType* src, UnsignedType* dst)
592 auto lambda = [=](
const size_t my_id) {
593 const size_t my_bgn = pos_bgn_[my_id];
594 const size_t my_end = pos_end_[my_id];
595 size_t* my_histo = histo_[my_id];
596 UnsignedType** my_buf = out_buf_[my_id];
597 size_t* my_buf_n = out_buf_n_[my_id];
599 memset(my_buf_n, 0,
sizeof(
size_t) * (1 << Base));
600 for (
size_t i = my_bgn; i < my_end; ++i)
602 const UnsignedType s = Encoder::encode(src[i]);
603 UnsignedType t = (s >> b) & ((1 << Base) - 1);
604 compare_internal_.reverse(t);
605 my_buf[t][my_buf_n[t]] = src[i];
606 value_manager_->Push(my_id, t, my_buf_n[t], i);
609 if (my_buf_n[t] == kOutBufferSize)
611 size_t p = my_histo[t];
612 for (
size_t j = 0; j < kOutBufferSize; ++j)
615 dst[tp] = my_buf[t][j];
617 value_manager_->Flush(my_id, t, kOutBufferSize, my_histo[t]);
619 my_histo[t] += kOutBufferSize;
625 for (
size_t i = 0; i < 1 << Base; ++i)
627 size_t p = my_histo[i];
628 for (
size_t j = 0; j < my_buf_n[i]; ++j)
631 dst[tp] = my_buf[i][j];
633 value_manager_->Flush(my_id, i, my_buf_n[i], my_histo[i]);
638 RunTask<PlainType, UnsignedType, Encoder, Base, std::function<void(
size_t)>, ThreaderType>;
639 RunTaskType root(0, 1, lambda, num_elems_, num_threads_, threader_);
640 this->threader_.RunParentTask(root);
652 class EncoderUnsigned
655 template <
typename Un
signedType>
656 inline static UnsignedType encode(UnsignedType x)
665 template <
typename Un
signedType>
666 inline static UnsignedType encode(UnsignedType x)
668 return x ^ (UnsignedType(1) << (CHAR_BIT *
sizeof(UnsignedType) - 1));
675 template <
typename Un
signedType>
676 inline static UnsignedType encode(UnsignedType x)
678 static const size_t bits = CHAR_BIT *
sizeof(UnsignedType);
679 const UnsignedType a = x >> (bits - 1);
680 const UnsignedType b = (-
static_cast<int>(a)) | (UnsignedType(1) << (bits - 1));
688 namespace value_manager
690 class DummyValueManager
693 inline void Push(
int thread,
size_t bucket,
size_t num,
size_t from_pos)
701 inline void Flush(
int thread,
size_t bucket,
size_t num,
size_t to_pos)
712 template <
typename PlainType,
typename ValueType,
int Base>
713 class PairValueManager
728 ~PairValueManager() { DeleteAll(); }
730 void Init(
size_t max_elems,
size_t available_threads);
732 void Start(ValueType* original,
size_t num_elems)
734 assert(num_elems <= max_elems_);
735 src_ = original_ = original;
739 inline void Push(
int thread,
size_t bucket,
size_t num,
size_t from_pos)
741 out_buf_[thread][bucket][num] = src_[from_pos];
744 inline void Flush(
int thread,
size_t bucket,
size_t num,
size_t to_pos)
746 for (
size_t i = 0; i < num; ++i)
748 dst_[to_pos++] = out_buf_[thread][bucket][i];
752 void Next() { std::swap(src_, dst_); }
754 ValueType* GetResult() {
return src_; }
760 static constexpr
size_t kOutBufferSize = internal::kOutBufferSize;
761 ValueType *original_, *tmp_;
762 ValueType *src_, *dst_;
763 ValueType*** out_buf_;
764 vtkm::UInt64 tmp_size;
769 template <
typename PlainType,
typename ValueType,
int Base>
770 void PairValueManager<PlainType, ValueType, Base>::Init(
size_t max_elems,
size_t available_cores)
774 max_elems_ = max_elems;
775 max_threads_ = utility::GetMaxThreads(max_elems_ *
sizeof(PlainType), available_cores);
778 tmp_size = max_elems *
sizeof(ValueType);
780 "Allocating working memory for radix sort-by-key: %s.",
782 tmp_ =
new ValueType[max_elems];
785 out_buf_ =
new ValueType**[max_threads_];
786 for (
int i = 0; i < max_threads_; ++i)
788 out_buf_[i] =
new ValueType*[1 << Base];
789 for (
size_t j = 0; j < 1 << Base; ++j)
791 out_buf_[i][j] =
new ValueType[kOutBufferSize];
796 template <
typename PlainType,
typename ValueType,
int Base>
797 void PairValueManager<PlainType, ValueType, Base>::DeleteAll()
801 "Freeing working memory for radix sort-by-key: %s.",
808 for (
int i = 0; i < max_threads_; ++i)
810 for (
size_t j = 0; j < 1 << Base; ++j)
812 delete[] out_buf_[i][j];
814 delete[] out_buf_[i];
825 template <
typename ThreaderType,
827 typename CompareType,
828 typename UnsignedType = PlainType,
829 typename Encoder = encoder::EncoderDummy,
830 unsigned int Base = 8>
833 using DummyValueManager = value_manager::DummyValueManager;
834 using Internal = internal::ParallelRadixSortInternal<PlainType,
843 void InitAndSort(PlainType* data,
845 const ThreaderType& threader,
846 const CompareType& comp)
849 DummyValueManager dvm;
850 Internal::InitAndSort(data, num_elems, threader, &dvm);
855 template <
typename ThreaderType,
858 typename CompareType,
859 typename UnsignedType = PlainType,
860 typename Encoder = encoder::EncoderDummy,
864 using ValueManager = value_manager::PairValueManager<PlainType, ValueType, Base>;
865 using Internal = internal::ParallelRadixSortInternal<PlainType,
874 void InitAndSort(PlainType* keys,
877 const ThreaderType& threader,
878 const CompareType& comp)
882 vm.Init(num_elems, threader.GetAvailableCores());
883 vm.Start(vals, num_elems);
884 Internal::InitAndSort(keys, num_elems, threader, &vm);
885 ValueType* res_vals = vm.GetResult();
886 if (res_vals != vals)
888 for (
size_t i = 0; i < num_elems; ++i)
890 vals[i] = res_vals[i];
898 #define KEY_SORT_CASE(plain_type, compare_type, unsigned_type, encoder_type) \
899 template <typename ThreaderType> \
900 class KeySort<ThreaderType, plain_type, compare_type> \
901 : public KeySort<ThreaderType, \
905 encoder::Encoder##encoder_type> \
908 template <typename V, typename ThreaderType> \
909 class PairSort<ThreaderType, plain_type, V, compare_type> \
910 : public PairSort<ThreaderType, \
915 encoder::Encoder##encoder_type> \
920 KEY_SORT_CASE(
unsigned int, std::less<unsigned int>,
unsigned int, Unsigned);
921 KEY_SORT_CASE(
unsigned int, std::greater<unsigned int>,
unsigned int, Unsigned);
922 KEY_SORT_CASE(
unsigned short int, std::less<unsigned short int>,
unsigned short int, Unsigned);
923 KEY_SORT_CASE(
unsigned short int, std::greater<unsigned short int>,
unsigned short int, Unsigned);
924 KEY_SORT_CASE(
unsigned long int, std::less<unsigned long int>,
unsigned long int, Unsigned);
925 KEY_SORT_CASE(
unsigned long int, std::greater<unsigned long int>,
unsigned long int, Unsigned);
927 std::less<unsigned long long int>,
928 unsigned long long int,
931 std::greater<unsigned long long int>,
932 unsigned long long int,
936 KEY_SORT_CASE(
unsigned char, std::less<unsigned char>,
unsigned char, Unsigned);
937 KEY_SORT_CASE(
unsigned char, std::greater<unsigned char>,
unsigned char, Unsigned);
938 KEY_SORT_CASE(char16_t, std::less<char16_t>, uint16_t, Unsigned);
939 KEY_SORT_CASE(char16_t, std::greater<char16_t>, uint16_t, Unsigned);
940 KEY_SORT_CASE(char32_t, std::less<char32_t>, uint32_t, Unsigned);
941 KEY_SORT_CASE(char32_t, std::greater<char32_t>, uint32_t, Unsigned);
942 KEY_SORT_CASE(
wchar_t, std::less<wchar_t>, uint32_t, Unsigned);
943 KEY_SORT_CASE(
wchar_t, std::greater<wchar_t>, uint32_t, Unsigned);
947 KEY_SORT_CASE(
char, std::greater<char>,
unsigned char, Signed);
948 KEY_SORT_CASE(
short, std::less<short>,
unsigned short, Signed);
949 KEY_SORT_CASE(
short, std::greater<short>,
unsigned short, Signed);
953 KEY_SORT_CASE(
long, std::greater<long>,
unsigned long, Signed);
954 KEY_SORT_CASE(
long long, std::less<long long>,
unsigned long long, Signed);
955 KEY_SORT_CASE(
long long, std::greater<long long>,
unsigned long long, Signed);
958 KEY_SORT_CASE(
signed char, std::less<signed char>,
unsigned char, Signed);
959 KEY_SORT_CASE(
signed char, std::greater<signed char>,
unsigned char, Signed);
963 KEY_SORT_CASE(
float, std::greater<float>, uint32_t, Decimal);
965 KEY_SORT_CASE(
double, std::greater<double>, uint64_t, Decimal);
969 template <
typename T,
typename CompareType>
970 struct run_kx_radix_sort_keys
972 static void run(T* data,
size_t num_elems,
const CompareType& comp)
974 std::sort(data, data + num_elems, comp);
978 #define KX_SORT_KEYS(key_type) \
980 struct run_kx_radix_sort_keys<key_type, std::less<key_type>> \
982 static void run(key_type* data, size_t num_elems, const std::less<key_type>& comp) \
985 kx::radix_sort(data, data + num_elems); \
1000 template <
typename T,
typename CompareType>
1001 bool use_serial_sort_keys(T* data,
size_t num_elems,
const CompareType& comp)
1003 size_t total_bytes = (num_elems) *
sizeof(T);
1004 if (total_bytes < MIN_BYTES_FOR_PARALLEL)
1006 run_kx_radix_sort_keys<T, CompareType>::run(data, num_elems, comp);
1013 #define VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(threader_type, key_type) \
1014 VTKM_CONT_EXPORT void parallel_radix_sort_key_values( \
1015 key_type* keys, vtkm::Id* vals, size_t num_elems, const std::greater<key_type>& comp) \
1017 using namespace vtkm::cont::internal::radix; \
1018 PairSort<threader_type, key_type, vtkm::Id, std::greater<key_type>> ps; \
1019 ps.InitAndSort(keys, vals, num_elems, threader_type(), comp); \
1021 VTKM_CONT_EXPORT void parallel_radix_sort_key_values( \
1022 key_type* keys, vtkm::Id* vals, size_t num_elems, const std::less<key_type>& comp) \
1024 using namespace vtkm::cont::internal::radix; \
1025 PairSort<threader_type, key_type, vtkm::Id, std::less<key_type>> ps; \
1026 ps.InitAndSort(keys, vals, num_elems, threader_type(), comp); \
1028 VTKM_CONT_EXPORT void parallel_radix_sort( \
1029 key_type* data, size_t num_elems, const std::greater<key_type>& comp) \
1031 using namespace vtkm::cont::internal::radix; \
1032 if (!use_serial_sort_keys(data, num_elems, comp)) \
1034 KeySort<threader_type, key_type, std::greater<key_type>> ks; \
1035 ks.InitAndSort(data, num_elems, threader_type(), comp); \
1038 VTKM_CONT_EXPORT void parallel_radix_sort( \
1039 key_type* data, size_t num_elems, const std::less<key_type>& comp) \
1041 using namespace vtkm::cont::internal::radix; \
1042 if (!use_serial_sort_keys(data, num_elems, comp)) \
1044 KeySort<threader_type, key_type, std::less<key_type>> ks; \
1045 ks.InitAndSort(data, num_elems, threader_type(), comp); \
1049 #define VTKM_INSTANTIATE_RADIX_SORT_FOR_THREADER(ThreaderType) \
1050 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, short int) \
1051 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned short int) \
1052 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, int) \
1053 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned int) \
1054 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, long int) \
1055 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned long int) \
1056 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, long long int) \
1057 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned long long int) \
1058 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, unsigned char) \
1059 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, signed char) \
1060 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, char) \
1061 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, char16_t) \
1062 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, char32_t) \
1063 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, wchar_t) \
1064 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, float) \
1065 VTKM_INTERNAL_RADIX_SORT_INSTANTIATE(ThreaderType, double)
1067 VTKM_THIRDPARTY_POST_INCLUDE
1073 #endif // vtk_m_cont_internal_ParallelRadixSort_h