! stdlib_selection.fypp Source File


! Source Code

#:include "common.fypp"
! Kinds/types accepted for the input array `a` in select and arg_select:
! the integer kinds/types followed by the real kinds/types.
#:set ARRAY_KINDS_TYPES = INT_KINDS_TYPES + REAL_KINDS_TYPES
! The index-like arguments (`k`, `left`, `right`, and `indx` in arg_select)
! are generated for every entry of INT_KINDS_TYPES.

module stdlib_selection
!! Quickly find the k-th smallest value of an array, or the index of the k-th smallest value.
!! ([Specification](../page/specs/stdlib_selection.html))
!
! This code was modified from the "Coretran" implementation "quickSelect" by
! Leon Foks, https://github.com/leonfoks/coretran/tree/HEAD/src/sorting
!
! Leon Foks gave permission to release this code under stdlib's MIT license.
! (https://github.com/fortran-lang/stdlib/pull/500#commitcomment-57418593)
!

use stdlib_kinds

implicit none

private

public :: select, arg_select

interface select
    !! version: experimental
    !!
    !! Generic name for finding the k-th smallest value of an array. One
    !! specific procedure is generated for every combination of array
    !! kind/type and index integer kind.
    !! ([Specification](../page/specs/stdlib_selection.html#select-find-the-k-th-smallest-value-in-an-input-array))
  #:for arraykind, arraytype in ARRAY_KINDS_TYPES
    #:for intkind, inttype in INT_KINDS_TYPES
      module procedure ${rname("select", 1, arraytype, arraykind, intkind)}$
    #:endfor
  #:endfor
end interface select

interface arg_select
    !! version: experimental
    !!
    !! Generic name for finding the index of the k-th smallest value of an
    !! array. One specific procedure is generated for every combination of
    !! array kind/type and index integer kind.
    !! ([Specification](../page/specs/stdlib_selection.html#arg_select-find-the-index-of-the-k-th-smallest-value-in-an-input-array))
  #:for arraykind, arraytype in ARRAY_KINDS_TYPES
    #:for intkind, inttype in INT_KINDS_TYPES
      module procedure ${rname("arg_select", 1, arraytype, arraykind, intkind)}$
    #:endfor
  #:endfor
end interface arg_select

contains

  #:for arraykind, arraytype in ARRAY_KINDS_TYPES
    #:for intkind, inttype in INT_KINDS_TYPES
      #:set name = rname("select", 1, arraytype, arraykind, intkind)
      subroutine ${name}$(a, k, kth_smallest, left, right)
          !! select - select the k-th smallest entry in a(:).
          !!
          !! Works on a search window `a(l:r)` (initially `a(left:right)`):
          !! a median-of-three pivot is moved to the front of the window, the
          !! window is partitioned around it, and the window is then shrunk
          !! to the side that contains position `k`. The loop ends when the
          !! pivot lands exactly on position `k`.
          !!
          !! Stops with `error stop` unless
          !! `1 <= left <= k <= right <= size(a)`.
          !!
          !! Partly derived from the "Coretran" implementation of
          !! quickSelect by Leon Foks, https://github.com/leonfoks/coretran
          !!
          ${arraytype}$, intent(inout) :: a(:)
              !! Array in which we seek the k-th smallest entry.
              !! On output it will be partially sorted such that
              !! `all(a(1:(k-1)) <= a(k)) .and. all(a(k) <= a((k+1):size(a)))`
              !! is true.
          ${inttype}$, intent(in) :: k
              !! We want the k-th smallest entry. E.G. `k=1` leads to
              !! `kth_smallest=minval(a)`, and `k=size(a)` leads to
              !! `kth_smallest=maxval(a)`
          ${arraytype}$, intent(out) :: kth_smallest
              !! On output contains the k-th smallest value of `a(:)`
          ${inttype}$, intent(in), optional :: left, right
              !! If we know that:
              !!    the k-th smallest entry of `a` is in `a(left:right)`
              !! and also that:
              !!    `maxval(a(1:(left-1))) <= minval(a(left:right))`
              !! and:
              !!    `maxval(a(left:right)) <= minval(a((right+1):size(a)))`
              !! then one or both bounds can be specified to narrow the search.
              !! The constraints are available if we have previously called the
              !! subroutine with different `k` (because of how `a(:)` becomes
              !! partially sorted, see documentation for `a(:)`).
              !! When absent, `left` defaults to 1 and `right` to `size(a)`.

          ${inttype}$ :: l, r, mid, iPivot
          integer, parameter :: ip = ${intkind}$

          ! Start from the whole array and let the optional bounds narrow it.
          ! NOTE(review): size(a, kind=ip) presumes size(a) is representable
          ! in the index kind - confirm for callers using small integer kinds.
          l = 1_ip
          if (present(left)) l = left
          r = size(a, kind=ip)
          if (present(right)) r = right

          if (l > r .or. l < 1_ip .or. r > size(a, kind=ip) &
              .or. k < l .or. k > r                         & ! i.e. k is not in the interval [l; r]
              ) then
              error stop "select must have  1 <= left <= k <= right <= size(a)"
          end if

          searchk: do
              mid = l + ((r-l)/2_ip) ! Avoid (l+r)/2 which can cause overflow

              ! Order a(l) <= a(mid) <= a(r), then move the median to a(l),
              ! which is where partition reads its pivot from.
              call medianOf3(a, l, mid, r)
              ! mid == l when the window holds one or two entries. Skip the
              ! swap then: passing the same element as both (modified)
              ! arguments of swap is not permitted by the Fortran standard.
              if (mid /= l) call swap(a(l), a(mid))
              call partition(a, l, r, iPivot)

              if (iPivot < k) then
                ! Position k lies to the right of the pivot.
                l = iPivot + 1_ip
              else if (iPivot > k) then
                ! Position k lies to the left of the pivot.
                r = iPivot - 1_ip
              else
                ! The pivot sits at position k: it is the k-th smallest.
                kth_smallest = a(k)
                return
              end if
          end do searchk

          contains
              pure subroutine swap(a, b)
                  !! Exchange the values of two distinct scalars.
                  ${arraytype}$, intent(inout) :: a, b
                  ${arraytype}$ :: tmp
                  tmp = a
                  a = b
                  b = tmp
              end subroutine swap

              pure subroutine medianOf3(a, left, mid, right)
                  !! Rearrange the three entries so that
                  !! `a(left) <= a(mid) <= a(right)`.
                  ${arraytype}$, intent(inout) :: a(:)
                  ${inttype}$, intent(in) :: left, mid, right
                  ! Each swap is only reached when the two entries compare
                  ! unequal, so they are always distinct elements.
                  if (a(right) < a(left)) call swap(a(right), a(left))
                  if (a(mid)   < a(left)) call swap(a(mid)  , a(left))
                  if (a(right) < a(mid) ) call swap(a(mid)  , a(right))
              end subroutine medianOf3

              pure subroutine partition(array,left,right,iPivot)
                  !! Partition `array(left:right)` around the pivot value
                  !! found in `array(left)`. On return the pivot is stored at
                  !! `array(iPivot)`, no entry of `array(left:iPivot-1)`
                  !! compares greater than it, and every entry of
                  !! `array(iPivot+1:right)` compares greater than it.
                  ${arraytype}$, intent(inout) :: array(:)
                  ${inttype}$, intent(in) :: left, right
                  ${inttype}$, intent(out) :: iPivot
                  ${inttype}$ :: lo,hi
                  ${arraytype}$ :: pivot

                  pivot = array(left)
                  lo = left
                  hi = right
                  do while (lo <= hi)
                    ! Move hi down to the last entry that is not greater
                    ! than the pivot.
                    do while (array(hi) > pivot)
                      hi = hi - 1_ip
                    end do
                    ! Move lo up to the first entry greater than the pivot,
                    ! without passing hi.
                    inner_lohi: do while (lo <= hi)
                      if (array(lo) > pivot) exit inner_lohi
                      lo = lo + 1_ip
                    end do inner_lohi
                    if (lo <= hi) then
                      ! array(lo) > pivot and array(hi) is not, so lo /= hi.
                      call swap(array(lo),array(hi))
                      lo = lo + 1_ip
                      hi = hi - 1_ip
                    end if
                  end do
                  ! Put the pivot at its final position. When hi == left it
                  ! is already there, and swapping an element with itself
                  ! would alias both arguments of swap.
                  if (hi /= left) call swap(array(left),array(hi))
                  iPivot = hi
              end subroutine partition
      end subroutine ${name}$
    #:endfor
  #:endfor


  #:for arraykind, arraytype in ARRAY_KINDS_TYPES
    #:for intkind, inttype in INT_KINDS_TYPES
      #:set name = rname("arg_select", 1, arraytype, arraykind, intkind)
      subroutine ${name}$(a, indx, k, kth_smallest, left, right)
          !! arg_select - find the index of the k-th smallest entry in `a(:)`
          !!
          !! Same search as `select`, but `a(:)` is left untouched: only the
          !! entries of `indx(:)` are rearranged, and values are always
          !! compared through `a(indx(...))`.
          !!
          !! Stops with `error stop` if `size(a) /= size(indx)` or unless
          !! `1 <= left <= k <= right <= size(a)`.
          !!
          !! Partly derived from the "Coretran" implementation of
          !! quickSelect by Leon Foks, https://github.com/leonfoks/coretran
          !!
          ${arraytype}$, intent(in) :: a(:)
              !! Array in which we seek the k-th smallest entry.
          ${inttype}$, intent(inout) :: indx(:)
              !! Array of indices into `a(:)`. Must contain each integer
              !! from `1:size(a)` exactly once. On output it will be partially
              !! sorted such that
              !! `all( a(indx(1:(k-1))) <= a(indx(k)) ) .AND.
              !!  all( a(indx(k))  <= a(indx( (k+1):size(a) )) )`.
          ${inttype}$, intent(in) :: k
              !! We want index of the k-th smallest entry. E.G. `k=1` leads to
              !! `a(kth_smallest) = minval(a)`, and `k=size(a)` leads to
              !! `a(kth_smallest) = maxval(a)`
          ${inttype}$, intent(out) :: kth_smallest
              !! On output contains the index with the k-th smallest value of `a(:)`
          ${inttype}$, intent(in), optional :: left, right
              !! If we know that:
              !!  the k-th smallest entry of `a` is in `a(indx(left:right))`
              !! and also that:
              !!  `maxval(a(indx(1:(left-1)))) <= minval(a(indx(left:right)))`
              !! and:
              !!  `maxval(a(indx(left:right))) <= minval(a(indx((right+1):size(a))))`
              !! then one or both bounds can be specified to reduce the search
              !! time. These constraints are available if we have previously
              !! called the subroutine with a different `k` (due to the way that
              !! `indx(:)` becomes partially sorted, see documentation for `indx(:)`).
              !! When absent, `left` defaults to 1 and `right` to `size(a)`.

          ${inttype}$ :: l, r, mid, iPivot
          integer, parameter :: ip = ${intkind}$

          ! Start from the whole array and let the optional bounds narrow it.
          ! NOTE(review): size(a, kind=ip) presumes size(a) is representable
          ! in the index kind - confirm for callers using small integer kinds.
          l = 1_ip
          if (present(left)) l = left
          r = size(a, kind=ip)
          if (present(right)) r = right

          if (size(a) /= size(indx)) then
              error stop "arg_select must have size(a) == size(indx)"
          end if

          if (l > r .or. l < 1_ip .or. r > size(a, kind=ip) &
              .or. k < l .or. k > r                         & ! i.e. k is not in the interval [l; r]
              ) then
              error stop "arg_select must have 1 <= left <= k <= right <= size(a)"
          end if

          searchk: do
              mid = l + ((r-l)/2_ip) ! Avoid (l+r)/2 which can cause overflow

              ! Order a(indx(l)) <= a(indx(mid)) <= a(indx(r)), then move the
              ! median's index to indx(l), where arg_partition reads its
              ! pivot from.
              call arg_medianOf3(a, indx, l, mid, r)
              ! mid == l when the window holds one or two entries. Skip the
              ! swap then: passing the same element as both (modified)
              ! arguments of swap is not permitted by the Fortran standard.
              if (mid /= l) call swap(indx(l), indx(mid))
              call arg_partition(a, indx, l, r, iPivot)

              if (iPivot < k) then
                ! Position k lies to the right of the pivot.
                l = iPivot + 1_ip
              else if (iPivot > k) then
                ! Position k lies to the left of the pivot.
                r = iPivot - 1_ip
              else
                ! The pivot sits at position k: indx(k) points at the
                ! k-th smallest value.
                kth_smallest = indx(k)
                return
              end if
          end do searchk

          contains
              pure subroutine swap(a, b)
                  !! Exchange the values of two distinct index scalars.
                  ${inttype}$, intent(inout) :: a, b
                  ${inttype}$ :: tmp
                  tmp = a
                  a = b
                  b = tmp
              end subroutine swap

              pure subroutine arg_medianOf3(a, indx, left, mid, right)
                  !! Rearrange three entries of `indx` so that
                  !! `a(indx(left)) <= a(indx(mid)) <= a(indx(right))`.
                  ${arraytype}$, intent(in) :: a(:)
                  ${inttype}$, intent(inout) :: indx(:)
                  ${inttype}$, intent(in) :: left, mid, right
                  ! Each swap is only reached when the two referenced values
                  ! compare unequal, so the two positions are always distinct.
                  if (a(indx(right)) < a(indx(left))) call swap(indx(right), indx(left))
                  if (a(indx(mid))   < a(indx(left))) call swap(indx(mid)  , indx(left))
                  if (a(indx(right)) < a(indx(mid)) ) call swap(indx(mid)  , indx(right))
              end subroutine arg_medianOf3

              pure subroutine arg_partition(array, indx, left,right,iPivot)
                  !! Partition `indx(left:right)` around the pivot value
                  !! `array(indx(left))`. On return the pivot's index is
                  !! stored at `indx(iPivot)`, no value referenced by
                  !! `indx(left:iPivot-1)` compares greater than the pivot,
                  !! and every value referenced by `indx(iPivot+1:right)`
                  !! compares greater than it.
                  ${arraytype}$, intent(in) :: array(:)
                  ${inttype}$, intent(inout) :: indx(:)
                  ${inttype}$, intent(in) :: left, right
                  ${inttype}$, intent(out) :: iPivot
                  ${inttype}$ :: lo,hi
                  ${arraytype}$ :: pivot

                  pivot = array(indx(left))
                  lo = left
                  hi = right
                  do while (lo <= hi)
                    ! Move hi down to the last position whose value is not
                    ! greater than the pivot.
                    do while (array(indx(hi)) > pivot)
                      hi = hi - 1_ip
                    end do
                    ! Move lo up to the first position whose value is greater
                    ! than the pivot, without passing hi.
                    inner_lohi: do while (lo <= hi)
                      if (array(indx(lo)) > pivot) exit inner_lohi
                      lo = lo + 1_ip
                    end do inner_lohi
                    if (lo <= hi) then
                      ! The value at lo is greater than the pivot and the
                      ! value at hi is not, so lo /= hi.
                      call swap(indx(lo),indx(hi))
                      lo = lo + 1_ip
                      hi = hi - 1_ip
                    end if
                  end do
                  ! Put the pivot's index at its final position. When
                  ! hi == left it is already there, and swapping an element
                  ! with itself would alias both arguments of swap.
                  if (hi /= left) call swap(indx(left),indx(hi))
                  iPivot = hi
              end subroutine arg_partition
      end subroutine ${name}$
    #:endfor
  #:endfor

end module