! SPDX-Identifier: MIT

#:include "common.fypp"
#:set RANKS = range(1, MAXRANK + 1)
#:set KINDS_TYPES = REAL_KINDS_TYPES + INT_KINDS_TYPES + CMPLX_KINDS_TYPES

!> Implementation of loading npy files into multidimensional arrays
submodule (stdlib_io_npy) stdlib_io_npy_load
    use stdlib_error, only : error_stop
    use stdlib_strings, only : to_string, starts_with
    implicit none

contains

#:for k1, t1 in KINDS_TYPES
  #:for rank in RANKS
    !> Load a ${rank}$-dimensional array from a npy file
    !>
    !> The file must already exist and is opened read-only. The type descriptor
    !> stored in the file must equal the one of the requested array and the stored
    !> shape must have exactly ${rank}$ dimensions. On failure the status is
    !> returned through `iostat` (and the message through `iomsg`) when present,
    !> otherwise execution is stopped via `error_stop`.
    module subroutine load_npy_${t1[0]}$${k1}$_${rank}$(filename, array, iostat, iomsg)
        !> Name of the npy file to load from
        character(len=*), intent(in) :: filename
        !> Array to be loaded from the npy file
        ${t1}$, allocatable, intent(out) :: array${ranksuffix(rank)}$
        !> Error status of loading, zero on success
        integer, intent(out), optional :: iostat
        !> Associated error message in case of non-zero status code
        character(len=:), allocatable, intent(out), optional :: iomsg

        character(len=*), parameter :: vtype = type_${t1[0]}$${k1}$
        integer, parameter :: rank = ${rank}$

        integer :: io, stat
        character(len=:), allocatable :: msg
        ! Fixed-length buffer, since `iomsg=` cannot target a deferred-length string
        character(len=512) :: io_errmsg

        ! status="old" turns a missing file into an error instead of silently
        ! creating an empty one; action="read" permits loading read-only files.
        open(newunit=io, file=filename, form="unformatted", access="stream", &
            & status="old", action="read", iostat=stat, iomsg=io_errmsg)

        if (stat /= 0) then
            ! The unit was never connected, so there is nothing to read or close
            msg = trim(io_errmsg)
        else
            catch: block
                character(len=:), allocatable :: this_type
                integer, allocatable :: vshape(:)

                call get_descriptor(io, filename, this_type, vshape, stat, msg)
                if (stat /= 0) exit catch

                if (this_type /= vtype) then
                    stat = 1
                    msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
                        & "but expected '"//vtype//"'"
                    exit catch
                end if

                if (size(vshape) /= rank) then
                    stat = 1
                    msg = "File '"//filename//"' contains data of rank "//&
                        & to_string(size(vshape))//", but expected "//&
                        & to_string(rank)
                    exit catch
                end if

                call allocator(array, vshape, stat)
                if (stat /= 0) then
                    msg = "Failed to allocate array of type '"//vtype//"' "//&
                        & "with total size of "//to_string(product(vshape))
                    exit catch
                end if

                read(io, iostat=stat, iomsg=io_errmsg) array
                if (stat /= 0) msg = trim(io_errmsg)
            end block catch
            close(io)
        end if

        if (present(iostat)) then
            iostat = stat
        else if (stat /= 0) then
            if (allocated(msg)) then
                call error_stop("Failed to read array from file '"//filename//"'"//nl//&
                    & msg)
            else
                call error_stop("Failed to read array from file '"//filename//"'")
            end if
        end if

        if (present(iomsg).and.allocated(msg)) call move_alloc(msg, iomsg)

    contains

        !> Wrapped intrinsic allocate to create an allocation from a shape array
        subroutine allocator(array, vshape, stat)
            !> Instance of the array to be allocated
            ${t1}$, allocatable, intent(out) :: array${ranksuffix(rank)}$
            !> Dimensions to allocate for
            integer, intent(in) :: vshape(:)
            !> Status of allocate
            integer, intent(out) :: stat

            allocate(array( &
#:for i in range(rank-1)
                & vshape(${i+1}$), &
#:endfor
                & vshape(${rank}$)), &
                & stat=stat)

        end subroutine allocator

    end subroutine load_npy_${t1[0]}$${k1}$_${rank}$
  #:endfor
#:endfor


    !> Read the npy header from a binary file and retrieve the descriptor string.
    !>
    !> On success the unit is positioned at the first byte of the array data.
    subroutine get_descriptor(io, filename, vtype, vshape, stat, msg)
        !> Unformatted, stream accessed unit
        integer, intent(in) :: io
        !> Filename for error reporting
        character(len=*), intent(in) :: filename
        !> Type of data saved in npy file
        character(len=:), allocatable, intent(out) :: vtype
        !> Shape descriptor of the stored array, reversed for C-ordered data
        integer, allocatable, intent(out) :: vshape(:)
        !> Status of operation
        integer, intent(out) :: stat
        !> Associated error message in case of non-zero status
        character(len=:), allocatable, intent(out) :: msg

        integer :: major, header_len, i
        character(len=:), allocatable :: dict
        character(len=8) :: header
        character :: buf(4)
        logical :: fortran_order

        ! stat should be zero if no error occurred
        stat = 0

        read(io, iostat=stat) header
        if (stat /= 0) return

        call parse_header(header, major, stat, msg)
        if (stat /= 0) return

        ! Major version 1 stores the header length in 2 bytes, later versions in 4
        read(io, iostat=stat) buf(1:merge(4, 2, major > 1))
        if (stat /= 0) return

        ! The length bytes are combined least significant byte first
        if (major > 1) then
            ! A most significant byte >= 128 would overflow a 32-bit default integer
            if (ichar(buf(4)) > 127) then
                stat = 1
                msg = "Descriptor length exceeds the supported range"
                return
            end if
            header_len = ichar(buf(1)) &
                &      + ichar(buf(2)) * 256**1 &
                &      + ichar(buf(3)) * 256**2 &
                &      + ichar(buf(4)) * 256**3
        else
            header_len = ichar(buf(1)) &
                &      + ichar(buf(2)) * 256**1
        end if

        ! The descriptor must at least contain its terminating newline
        if (header_len < 1) then
            stat = 1
            msg = "Invalid descriptor length "//to_string(header_len)
            return
        end if

        allocate(character(header_len) :: dict, stat=stat)
        if (stat /= 0) return

        read(io, iostat=stat) dict
        if (stat /= 0) return

        if (dict(header_len:header_len) /= nl) then
            stat = 1
            msg = "Descriptor length does not match"
            return
        end if

        if (scan(dict, achar(0)) > 0) then
            stat = 1
            msg = "Nul byte not allowed in descriptor string"
            return
        end if

        ! Drop the terminating newline and trailing padding before parsing
        call parse_descriptor(trim(dict(:len(dict)-1)), filename, &
            & vtype, fortran_order, vshape, stat, msg)
        if (stat /= 0) return

        ! Data stored in C (row-major) order is loaded with reversed dimensions,
        ! i.e. as the transpose of the original array
        if (.not.fortran_order) then
            vshape = [(vshape(i), i = size(vshape), 1, -1)]
        end if
    end subroutine get_descriptor


    !> Parse the first eight bytes of the npy header to verify the data
    subroutine parse_header(header, major, stat, msg)
        !> Header of the binary file
        character(len=*), intent(in) :: header
        !> Major version of the npy format
        integer, intent(out) :: major
        !> Status of operation
        integer, intent(out) :: stat
        !> Associated error message in case of non-zero status
        character(len=:), allocatable, intent(out) :: msg

        integer :: minor
        ! Two-digit hexadecimal rendering of an unexpected first byte
        character(len=2) :: hex_byte

        ! stat should be zero if no error occurred
        stat = 0

        if (header(1:1) /= magic_number) then
            ! Print the byte in hex so it matches the z'..' notation of the message
            write(hex_byte, '(z2.2)') ichar(header(1:1))
            stat = 1
            msg = "Expected z'93' but got z'"//hex_byte//"' "//&
                & "as first byte"
            return
        end if

        if (header(2:6) /= magic_string) then
            stat = 1
            msg = "Expected identifier '"//magic_string//"'"
            return
        end if

        major = ichar(header(7:7))
        if (.not.any(major == [1, 2, 3])) then
            stat = 1
            msg = "Unsupported format major version number '"//to_string(major)//"'"
            return
        end if

        minor = ichar(header(8:8))
        if (minor /= 0) then
            stat = 1
            msg = "Unsupported format version "// &
                & "'"//to_string(major)//"."//to_string(minor)//"'"
            return
        end if
    end subroutine parse_header


    !> Parse the descriptor in the npy header. This routine implements a minimal
    !> non-recursive parser for serialized Python dictionaries.
    !>
    !> The dictionary must be closed by `}` and contain each of the entries
    !> `descr`, `fortran_order` and `shape` exactly once; no other keys are accepted.
    subroutine parse_descriptor(input, filename, vtype, fortran_order, vshape, stat, msg)
        !> Input string to parse as descriptor
        character(len=*), intent(in) :: input
        !> Filename for error reporting
        character(len=*), intent(in) :: filename
        !> Type of the data stored, retrieved from field `descr`
        character(len=:), allocatable, intent(out) :: vtype
        !> Whether the data is in left layout, retrieved from field `fortran_order`
        logical, intent(out) :: fortran_order
        !> Shape of the stored data, retrieved from field `shape`
        integer, allocatable, intent(out) :: vshape(:)
        !> Status of operation
        integer, intent(out) :: stat
        !> Associated error message in case of non-zero status
        character(len=:), allocatable, intent(out) :: msg

        enum, bind(c)
            enumerator :: invalid, string, lbrace, rbrace, comma, colon, &
                & lparen, rparen, bool, literal, space
        end enum

        type :: token_type
            integer :: first, last, kind
        end type token_type

        integer :: pos
        character(len=:), allocatable :: key
        type(token_type) :: token, last
        logical :: has_descr, has_shape, has_fortran_order, closed

        has_descr = .false.
        has_shape = .false.
        has_fortran_order = .false.
        closed = .false.

        pos = 0
        call next_token(input, pos, token, [lbrace], stat, msg)
        if (stat /= 0) return

        ! Treat the opening brace like a preceding comma, so a comma directly
        ! after `{` is rejected by the duplicate-kind check below
        last = token_type(pos, pos, comma)
        do while (pos < len(input))
            call get_token(input, pos, token)
            select case(token%kind)
            case(space)
                continue
            case(comma)
                if (token%kind == last%kind) then
                    stat = 1
                    msg = make_message(filename, input, token%first, token%last, &
                        & "Comma cannot appear at this point")
                    return
                end if
                last = token
            case(rbrace)
                closed = .true.
                exit
            case(string)
                if (token%kind == last%kind) then
                    stat = 1
                    msg = make_message(filename, input, token%first, token%last, &
                        & "String cannot appear at this point")
                    return
                end if
                last = token

                key = input(token%first+1:token%last-1)
                call next_token(input, pos, token, [colon], stat, msg)
                if (stat /= 0) return

                if (key == "descr" .and. has_descr &
                    & .or. key == "fortran_order" .and. has_fortran_order &
                    & .or. key == "shape" .and. has_shape) then
                    stat = 1
                    msg = make_message(filename, input, last%first, last%last, &
                        & "Duplicate entry for '"//key//"' found")
                    return
                end if

                select case(key)
                case("descr")
                    call next_token(input, pos, token, [string], stat, msg)
                    if (stat /= 0) return
                    vtype = input(token%first+1:token%last-1)
                    has_descr = .true.

                case("fortran_order")
                    call next_token(input, pos, token, [bool], stat, msg)
                    if (stat /= 0) return
                    fortran_order = input(token%first:token%last) == "True"
                    has_fortran_order = .true.

                case("shape")
                    call parse_tuple(input, pos, vshape, stat, msg)
                    ! Propagate tuple errors before later tokens can reset stat
                    if (stat /= 0) return
                    has_shape = .true.

                case default
                    stat = 1
                    msg = make_message(filename, input, last%first, last%last, &
                        & "Invalid entry '"//key//"' in dictionary encountered")
                    return
                end select
            case default
                stat = 1
                msg = make_message(filename, input, token%first, token%last, &
                    & "Invalid token encountered")
                return
            end select
        end do

        if (.not.closed) then
            stat = 1
            msg = make_message(filename, input, 1, pos, &
                & "Dictionary is not terminated by a closing brace")
            return
        end if

        if (.not.has_descr) then
            stat = 1
            msg = make_message(filename, input, 1, pos, &
                & "Dictionary does not contain required entry 'descr'")
            return
        end if

        if (.not.has_shape) then
            stat = 1
            msg = make_message(filename, input, 1, pos, &
                & "Dictionary does not contain required entry 'shape'")
            return
        end if

        if (.not.has_fortran_order) then
            stat = 1
            msg = make_message(filename, input, 1, pos, &
                & "Dictionary does not contain required entry 'fortran_order'")
            return
        end if

    contains

        !> Build a diagnostic that reprints the input and underlines input(first:last)
        function make_message(filename, input, first, last, message) result(str)
            !> Filename for context
            character(len=*), intent(in) :: filename
            !> Input string to parse
            character(len=*), intent(in) :: input
            !> Offset in the input
            integer, intent(in) :: first, last
            !> Error message
            character(len=*), intent(in) :: message
            !> Final output message
            character(len=:), allocatable :: str

            character(len=*), parameter :: nl = new_line('a')

            ! The input is printed after the 4-character prefix "1 | ", so
            ! input(first) sits in column first+4. The caret line uses the
            ! 3-character gutter "  |" plus `first` blanks, so the first caret
            ! lands in that same column.
            str = message // nl // &
                & " --> " // filename // ":1:" // to_string(first) // "-" // to_string(last) // nl // &
                & "  |" // nl // &
                & "1 | " // input // nl // &
                & "  |" // repeat(" ", first) // repeat("^", last - first + 1) // nl // &
                & "  |"
        end function make_message

        !> Parse a tuple of integers into an array of integers
        subroutine parse_tuple(input, pos, tuple, stat, msg)
            !> Input string to parse
            character(len=*), intent(in) :: input
            !> Offset in the input, will be advanced after reading
            integer, intent(inout) :: pos
            !> Array representing tuple of integers
            integer, allocatable, intent(out) :: tuple(:)
            !> Status of operation
            integer, intent(out) :: stat
            !> Associated error message in case of non-zero status
            character(len=:), allocatable, intent(out) :: msg

            type(token_type) :: token
            integer :: last, itmp

            allocate(tuple(0), stat=stat)
            if (stat /= 0) return

            call next_token(input, pos, token, [lparen], stat, msg)
            if (stat /= 0) return

            ! Start as if after a comma, so a leading comma is rejected
            last = comma
            do while (pos < len(input))
                call get_token(input, pos, token)
                select case(token%kind)
                case(space)
                    continue
                case(literal)
                    if (token%kind == last) then
                        stat = 1
                        msg = make_message(filename, input, token%first, token%last, &
                            & "Invalid token encountered")
                        return
                    end if
                    last = token%kind
                    read(input(token%first:token%last), *, iostat=stat) itmp
                    if (stat /= 0) then
                        msg = make_message(filename, input, token%first, token%last, &
                            & "Invalid integer literal in shape")
                        return
                    end if
                    tuple = [tuple, itmp]
                case(comma)
                    if (token%kind == last) then
                        stat = 1
                        msg = make_message(filename, input, token%first, token%last, &
                            & "Invalid token encountered")
                        return
                    end if
                    last = token%kind
                case(rparen)
                    exit
                case default
                    stat = 1
                    msg = make_message(filename, input, token%first, token%last, &
                        & "Invalid token encountered")
                    return
                end select
            end do
        end subroutine parse_tuple

        !> Get the next non-space token and check that it is one of the allowed kinds
        subroutine next_token(input, pos, token, allowed_token, stat, msg)
            !> Input string to parse
            character(len=*), intent(in) :: input
            !> Current offset in the input string
            integer, intent(inout) :: pos
            !> Last token parsed
            type(token_type), intent(out) :: token
            !> Tokens allowed in the current context
            integer, intent(in) :: allowed_token(:)
            !> Status of operation
            integer, intent(out) :: stat
            !> Associated error message in case of non-zero status
            character(len=:), allocatable, intent(out) :: msg

            do while (pos < len(input))
                call get_token(input, pos, token)
                if (token%kind == space) cycle
                if (any(token%kind == allowed_token)) then
                    stat = 0
                else
                    stat = 1
                    msg = make_message(filename, input, token%first, token%last, &
                        & "Invalid token encountered")
                end if
                return
            end do

            ! Only whitespace (or nothing) remained, so no token could be read
            stat = 1
            msg = make_message(filename, input, len(input), len(input), &
                & "Unexpected end of input")
        end subroutine next_token

        !> Tokenize input string
        subroutine get_token(input, pos, token)
            !> Input string to tokenize
            character(len=*), intent(in) :: input
            !> Offset in input string, will be advanced to the last character of the token
            integer, intent(inout) :: pos
            !> Returned token from the next position
            type(token_type), intent(out) :: token

            integer :: start, offset

            pos = pos + 1
            start = pos

            select case(input(pos:pos))
            case("""", "'")
                ! Look for the matching closing quote; an unterminated string is invalid
                offset = index(input(pos+1:), input(pos:pos))
                if (offset > 0) then
                    pos = pos + offset
                    token = token_type(start, pos, string)
                else
                    pos = len(input)
                    token = token_type(start, pos, invalid)
                end if
            case("0":"9")
                ! Consume the whole run of digits, leaving pos on the last digit
                offset = verify(input(pos:), "0123456789")
                if (offset == 0) then
                    pos = len(input)
                else
                    pos = pos + offset - 2
                end if
                token = token_type(start, pos, literal)
            case("T")
                if (starts_with(input(pos:), "True")) then
                    token = token_type(pos, pos+3, bool)
                    pos = pos + 3
                else
                    token = token_type(pos, pos, invalid)
                end if
            case("F")
                if (starts_with(input(pos:), "False")) then
                    token = token_type(pos, pos+4, bool)
                    pos = pos + 4
                else
                    token = token_type(pos, pos, invalid)
                end if
            case("{")
                token = token_type(pos, pos, lbrace)
            case("}")
                token = token_type(pos, pos, rbrace)
            case(",")
                token = token_type(pos, pos, comma)
            case(":")
                token = token_type(pos, pos, colon)
            case("(")
                token = token_type(pos, pos, lparen)
            case(")")
                token = token_type(pos, pos, rparen)
            case(" ", nl)
                token = token_type(pos, pos, space)
            case default
                token = token_type(pos, pos, invalid)
            end select

        end subroutine get_token

    end subroutine parse_descriptor

end submodule stdlib_io_npy_load