#:include "common.fypp"
! Specify kinds/types for the input array in select and arg_select
#:set ARRAY_KINDS_TYPES = INT_KINDS_TYPES + REAL_KINDS_TYPES
! The index arrays are of all INT_KINDS_TYPES

module stdlib_selection
!! Quickly find the k-th smallest value of an array, or the index of the k-th smallest value.
!! ([Specification](../page/specs/stdlib_selection.html))
!
! This code was modified from the "Coretran" implementation "quickSelect" by
! Leon Foks, https://github.com/leonfoks/coretran/tree/HEAD/src/sorting
!
! Leon Foks gave permission to release this code under stdlib's MIT license.
! (https://github.com/fortran-lang/stdlib/pull/500#commitcomment-57418593)
!

use stdlib_kinds

implicit none

private

public :: select, arg_select

interface select
    !! version: experimental
    !! ([Specification](../page/specs/stdlib_selection.html#select-find-the-k-th-smallest-value-in-an-input-array))
  #:for arraykind, arraytype in ARRAY_KINDS_TYPES
    #:for intkind, inttype in INT_KINDS_TYPES
      #:set name = rname("select", 1, arraytype, arraykind, intkind)
    module procedure ${name}$
    #:endfor
  #:endfor
end interface select

interface arg_select
    !! version: experimental
    !! ([Specification](../page/specs/stdlib_selection.html#arg_select-find-the-index-of-the-k-th-smallest-value-in-an-input-array))
  #:for arraykind, arraytype in ARRAY_KINDS_TYPES
    #:for intkind, inttype in INT_KINDS_TYPES
      #:set name = rname("arg_select", 1, arraytype, arraykind, intkind)
    module procedure ${name}$
    #:endfor
  #:endfor
end interface arg_select

contains

#:for arraykind, arraytype in ARRAY_KINDS_TYPES
  #:for intkind, inttype in INT_KINDS_TYPES
    #:set name = rname("select", 1, arraytype, arraykind, intkind)
    subroutine ${name}$(a, k, kth_smallest, left, right)
        !! select - select the k-th smallest entry in a(:).
        !!
        !! Partly derived from the "Coretran" implementation of
        !! quickSelect by Leon Foks, https://github.com/leonfoks/coretran
        !!
        !! Stops with `error stop` unless
        !! `1 <= left <= k <= right <= size(a)`.
        !!
        ${arraytype}$, intent(inout) :: a(:)
            !! Array in which we seek the k-th smallest entry.
            !! On output it will be partially sorted such that
            !! `all(a(1:(k-1)) <= a(k)) .and. all(a(k) <= a((k+1):size(a)))`
            !! is true.
        ${inttype}$, intent(in) :: k
            !! We want the k-th smallest entry. E.G. `k=1` leads to
            !! `kth_smallest=min(a)`, and `k=size(a)` leads to
            !! `kth_smallest=max(a)`
        ${arraytype}$, intent(out) :: kth_smallest
            !! On output contains the k-th smallest value of `a(:)`
        ${inttype}$, intent(in), optional :: left, right
            !! If we know that:
            !!    the k-th smallest entry of `a` is in `a(left:right)`
            !! and also that:
            !!    `maxval(a(1:(left-1))) <= minval(a(left:right))`
            !! and:
            !!    `maxval(a(left:right)) <= minval(a((right+1):size(a)))`
            !! then one or both bounds can be specified to narrow the search.
            !! The constraints are available if we have previously called the
            !! subroutine with different `k` (because of how `a(:)` becomes
            !! partially sorted, see documentation for `a(:)`).

        ${inttype}$ :: l, r, mid, iPivot
        integer, parameter :: ip = ${intkind}$

        ! Default search range is the whole array.
        l = 1_ip
        if (present(left)) l = left
        r = size(a, kind=ip)
        if (present(right)) r = right

        ! The last two tests reject any k outside the interval [l; r].
        if (l > r .or. l < 1_ip .or. r > size(a, kind=ip) &
            .or. k < l .or. k > r) then
            error stop "select must have 1 <= left <= k <= right <= size(a)"
        end if

        searchk: do
            mid = l + ((r-l)/2_ip) ! Avoid (l+r)/2 which can cause overflow

            ! Order a(l) <= a(mid) <= a(r), then move the median to position
            ! l, which is where partition() takes its pivot from.
            call medianOf3(a, l, mid, r)
            ! When the range holds one or two elements mid == l; skipping the
            ! swap avoids passing the same element as both arguments.
            if (mid /= l) call swap(a(l), a(mid))
            call partition(a, l, r, iPivot)

            ! The pivot is now at its final sorted position iPivot; keep only
            ! the side of it that contains position k.
            if (iPivot < k) then
                l = iPivot + 1_ip
            else if (iPivot > k) then
                r = iPivot - 1_ip
            else
                kth_smallest = a(k)
                return
            end if
        end do searchk

        contains

            pure subroutine swap(a, b)
                !! Exchange the values of two distinct scalars.
                ${arraytype}$, intent(inout) :: a, b
                ${arraytype}$ :: tmp
                tmp = a
                a = b
                b = tmp
            end subroutine swap

            pure subroutine medianOf3(a, left, mid, right)
                !! Reorder the three entries so that
                !! `a(left) <= a(mid) <= a(right)`.
                ${arraytype}$, intent(inout) :: a(:)
                ${inttype}$, intent(in) :: left, mid, right
                ! Each swap is guarded by a strict `<`, so it is never called
                ! with two coinciding positions.
                if (a(right) < a(left)) call swap(a(right), a(left))
                if (a(mid)   < a(left)) call swap(a(mid)  , a(left))
                if (a(right) < a(mid) ) call swap(a(mid)  , a(right))
            end subroutine medianOf3

            pure subroutine partition(array, left, right, iPivot)
                !! Partition `array(left:right)` around `pivot = array(left)`.
                !! On return `array(iPivot) == pivot`, the entries of
                !! `array(left:iPivot-1)` are `<= pivot` and the entries of
                !! `array(iPivot+1:right)` are `> pivot`.
                ${arraytype}$, intent(inout) :: array(:)
                ${inttype}$, intent(in) :: left, right
                ${inttype}$, intent(out) :: iPivot
                ${inttype}$ :: lo, hi
                ${arraytype}$ :: pivot

                pivot = array(left)
                lo = left
                hi = right
                do while (lo <= hi)
                    ! Move hi down past entries that already belong on the
                    ! right; array(left) == pivot stops this scan.
                    do while (array(hi) > pivot)
                        hi = hi - 1_ip
                    end do
                    ! Move lo up past entries that already belong on the left.
                    inner_lohi: do while (lo <= hi)
                        if (array(lo) > pivot) exit inner_lohi
                        lo = lo + 1_ip
                    end do inner_lohi
                    ! array(lo) > pivot >= array(hi): both are misplaced.
                    if (lo <= hi) then
                        call swap(array(lo), array(hi))
                        lo = lo + 1_ip
                        hi = hi - 1_ip
                    end if
                end do
                ! Put the pivot at its final position; skip when it is
                ! already there so swap() never gets the same element twice.
                if (hi /= left) call swap(array(left), array(hi))
                iPivot = hi
            end subroutine partition
    end subroutine ${name}$
  #:endfor
#:endfor


#:for arraykind, arraytype in ARRAY_KINDS_TYPES
  #:for intkind, inttype in INT_KINDS_TYPES
    #:set name = rname("arg_select", 1, arraytype, arraykind, intkind)
    subroutine ${name}$(a, indx, k, kth_smallest, left, right)
        !! arg_select - find the index of the k-th smallest entry in `a(:)`
        !!
        !! Partly derived from the "Coretran" implementation of
        !! quickSelect by Leon Foks, https://github.com/leonfoks/coretran
        !!
        !! Stops with `error stop` unless `size(a) == size(indx)` and
        !! `1 <= left <= k <= right <= size(a)`.
        !!
        ${arraytype}$, intent(in) :: a(:)
            !! Array in which we seek the k-th smallest entry.
        ${inttype}$, intent(inout) :: indx(:)
            !! Array of indices into `a(:)`. Must contain each integer
            !! from `1:size(a)` exactly once. On output it will be partially
            !! sorted such that
            !! `all( a(indx(1:(k-1))) <= a(indx(k)) ) .AND.
            !!  all( a(indx(k)) <= a(indx( (k+1):size(a) )) )`.
        ${inttype}$, intent(in) :: k
            !! We want index of the k-th smallest entry. E.G. `k=1` leads to
            !! `a(kth_smallest) = min(a)`, and `k=size(a)` leads to
            !! `a(kth_smallest) = max(a)`
        ${inttype}$, intent(out) :: kth_smallest
            !! On output contains the index with the k-th smallest value of `a(:)`
        ${inttype}$, intent(in), optional :: left, right
            !! If we know that:
            !!  the k-th smallest entry of `a` is in `a(indx(left:right))`
            !! and also that:
            !!  `maxval(a(indx(1:(left-1)))) <= minval(a(indx(left:right)))`
            !! and:
            !!  `maxval(a(indx(left:right))) <= minval(a(indx((right+1):size(a))))`
            !! then one or both bounds can be specified to reduce the search
            !! time. These constraints are available if we have previously
            !! called the subroutine with a different `k` (due to the way that
            !! `indx(:)` becomes partially sorted, see documentation for `indx(:)`).

        ${inttype}$ :: l, r, mid, iPivot
        integer, parameter :: ip = ${intkind}$

        ! Default search range is the whole array.
        l = 1_ip
        if (present(left)) l = left
        r = size(a, kind=ip)
        if (present(right)) r = right

        if (size(a) /= size(indx)) then
            error stop "arg_select must have size(a) == size(indx)"
        end if

        ! The last two tests reject any k outside the interval [l; r].
        if (l > r .or. l < 1_ip .or. r > size(a, kind=ip) &
            .or. k < l .or. k > r) then
            error stop "arg_select must have 1 <= left <= k <= right <= size(a)"
        end if

        searchk: do
            mid = l + ((r-l)/2_ip) ! Avoid (l+r)/2 which can cause overflow

            ! Order a(indx(l)) <= a(indx(mid)) <= a(indx(r)), then move the
            ! index of the median to position l, which is where
            ! arg_partition() takes its pivot from.
            call arg_medianOf3(a, indx, l, mid, r)
            ! When the range holds one or two elements mid == l; skipping the
            ! swap avoids passing the same element as both arguments.
            if (mid /= l) call swap(indx(l), indx(mid))
            call arg_partition(a, indx, l, r, iPivot)

            ! The pivot is now at its final sorted position iPivot; keep only
            ! the side of it that contains position k.
            if (iPivot < k) then
                l = iPivot + 1_ip
            else if (iPivot > k) then
                r = iPivot - 1_ip
            else
                kth_smallest = indx(k)
                return
            end if
        end do searchk

        contains

            pure subroutine swap(a, b)
                !! Exchange the values of two distinct scalars.
                ${inttype}$, intent(inout) :: a, b
                ${inttype}$ :: tmp
                tmp = a
                a = b
                b = tmp
            end subroutine swap

            pure subroutine arg_medianOf3(a, indx, left, mid, right)
                !! Reorder three entries of `indx` so that
                !! `a(indx(left)) <= a(indx(mid)) <= a(indx(right))`.
                ${arraytype}$, intent(in) :: a(:)
                ${inttype}$, intent(inout) :: indx(:)
                ${inttype}$, intent(in) :: left, mid, right
                ! Each swap is guarded by a strict `<`, so it is never called
                ! with two coinciding positions.
                if (a(indx(right)) < a(indx(left))) call swap(indx(right), indx(left))
                if (a(indx(mid))   < a(indx(left))) call swap(indx(mid)  , indx(left))
                if (a(indx(right)) < a(indx(mid)) ) call swap(indx(mid)  , indx(right))
            end subroutine arg_medianOf3

            pure subroutine arg_partition(array, indx, left, right, iPivot)
                !! Partition `indx(left:right)` around
                !! `pivot = array(indx(left))`. On return
                !! `array(indx(iPivot)) == pivot`, the entries of
                !! `array(indx(left:iPivot-1))` are `<= pivot` and the entries
                !! of `array(indx(iPivot+1:right))` are `> pivot`.
                !! Only `indx` is modified.
                ${arraytype}$, intent(in) :: array(:)
                ${inttype}$, intent(inout) :: indx(:)
                ${inttype}$, intent(in) :: left, right
                ${inttype}$, intent(out) :: iPivot
                ${inttype}$ :: lo, hi
                ${arraytype}$ :: pivot

                pivot = array(indx(left))
                lo = left
                hi = right
                do while (lo <= hi)
                    ! Move hi down past entries that already belong on the
                    ! right; array(indx(left)) == pivot stops this scan.
                    do while (array(indx(hi)) > pivot)
                        hi = hi - 1_ip
                    end do
                    ! Move lo up past entries that already belong on the left.
                    inner_lohi: do while (lo <= hi)
                        if (array(indx(lo)) > pivot) exit inner_lohi
                        lo = lo + 1_ip
                    end do inner_lohi
                    ! array(indx(lo)) > pivot >= array(indx(hi)): both are
                    ! misplaced.
                    if (lo <= hi) then
                        call swap(indx(lo), indx(hi))
                        lo = lo + 1_ip
                        hi = hi - 1_ip
                    end if
                end do
                ! Put the pivot at its final position; skip when it is
                ! already there so swap() never gets the same element twice.
                if (hi /= left) call swap(indx(left), indx(hi))
                iPivot = hi
            end subroutine arg_partition
    end subroutine ${name}$
  #:endfor
#:endfor

end module stdlib_selection