From d01f657ceef1eb96db5911da0f1899f84ed4a470 Mon Sep 17 00:00:00 2001 From: Ivan Date: Wed, 8 Apr 2020 01:16:31 +0200 Subject: [PATCH] Implementation and tests for diag, eye, and trace. Co-Authored-By: Jeremie Vandenplas --- src/CMakeLists.txt | 2 + src/Makefile.manual | 5 + src/stdlib_experimental_linalg.fypp | 80 ++++ src/stdlib_experimental_linalg.md | 156 ++++++++ src/stdlib_experimental_linalg_diag.fypp | 80 ++++ src/tests/CMakeLists.txt | 1 + src/tests/linalg/CMakeLists.txt | 2 + src/tests/linalg/test_linalg.f90 | 445 +++++++++++++++++++++++ 8 files changed, 771 insertions(+) create mode 100644 src/stdlib_experimental_linalg.fypp create mode 100644 src/stdlib_experimental_linalg.md create mode 100644 src/stdlib_experimental_linalg_diag.fypp create mode 100644 src/tests/linalg/CMakeLists.txt create mode 100644 src/tests/linalg/test_linalg.f90 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 043e6932d..0b13a9f93 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,6 +3,8 @@ # Create a list of the files to be preprocessed set(fppFiles stdlib_experimental_io.fypp + stdlib_experimental_linalg.fypp + stdlib_experimental_linalg_diag.fypp stdlib_experimental_optval.fypp stdlib_experimental_stats.fypp stdlib_experimental_stats_mean.fypp diff --git a/src/Makefile.manual b/src/Makefile.manual index 7cc8816ea..55f0352ed 100644 --- a/src/Makefile.manual +++ b/src/Makefile.manual @@ -2,6 +2,8 @@ SRC = f18estop.f90 \ stdlib_experimental_ascii.f90 \ stdlib_experimental_error.f90 \ stdlib_experimental_io.f90 \ + stdlib_experimental_linalg.f90 \ + stdlib_experimental_linalg_diag.f90 \ stdlib_experimental_kinds.f90 \ stdlib_experimental_optval.f90 \ stdlib_experimental_quadrature.f90 \ @@ -42,6 +44,7 @@ stdlib_experimental_io.o: \ stdlib_experimental_error.o \ stdlib_experimental_optval.o \ stdlib_experimental_kinds.o +stdlib_experimental_linalg_diag.o: stdlib_experimental_kinds.o stdlib_experimental_optval.o: stdlib_experimental_kinds.o stdlib_experimental_quadrature.o: stdlib_experimental_kinds.o stdlib_experimental_stats_mean.o: \ @@ -59,6 +62,8 @@ stdlib_experimental_stats_var.o: \ # Fortran sources that are built from fypp templates stdlib_experimental_io.f90: stdlib_experimental_io.fypp +stdlib_experimental_linalg.f90: stdlib_experimental_linalg.fypp +stdlib_experimental_linalg_diag.f90: stdlib_experimental_linalg_diag.fypp stdlib_experimental_quadrature.f90: stdlib_experimental_quadrature.fypp stdlib_experimental_stats.f90: stdlib_experimental_stats.fypp stdlib_experimental_stats_mean.f90: stdlib_experimental_stats_mean.fypp diff --git a/src/stdlib_experimental_linalg.fypp b/src/stdlib_experimental_linalg.fypp new file mode 100644 index 000000000..b3c3730d1 --- /dev/null +++ b/src/stdlib_experimental_linalg.fypp @@ -0,0 +1,80 @@ +#:include "common.fypp" +#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES +module stdlib_experimental_linalg + use stdlib_experimental_kinds, only: sp, dp, qp, & + int8, int16, int32, int64 + implicit none + private + + public :: diag + public :: eye + public :: trace + + interface diag + ! + ! Vector to matrix + ! + #:for k1, t1 in RCI_KINDS_TYPES + module function diag_${t1[0]}$${k1}$(v) result(res) + ${t1}$, intent(in) :: v(:) + ${t1}$ :: res(size(v),size(v)) + end function diag_${t1[0]}$${k1}$ + #:endfor + #:for k1, t1 in RCI_KINDS_TYPES + module function diag_${t1[0]}$${k1}$_k(v,k) result(res) + ${t1}$, intent(in) :: v(:) + integer, intent(in) :: k + ${t1}$ :: res(size(v)+abs(k),size(v)+abs(k)) + end function diag_${t1[0]}$${k1}$_k + #:endfor + + ! + ! Matrix to vector + ! + #:for k1, t1 in RCI_KINDS_TYPES + module function diag_${t1[0]}$${k1}$_mat(A) result(res) + ${t1}$, intent(in) :: A(:,:) + ${t1}$ :: res(minval(shape(A))) + end function diag_${t1[0]}$${k1}$_mat + #:endfor + #:for k1, t1 in RCI_KINDS_TYPES + module function diag_${t1[0]}$${k1}$_mat_k(A,k) result(res) + ${t1}$, intent(in) :: A(:,:) + integer, intent(in) :: k + ${t1}$ :: res(minval(shape(A))-abs(k)) + end function diag_${t1[0]}$${k1}$_mat_k + #:endfor + end interface + + ! Matrix trace + interface trace + #:for k1, t1 in RCI_KINDS_TYPES + module procedure trace_${t1[0]}$${k1}$ + #:endfor + end interface + +contains + + function eye(n) result(res) + integer, intent(in) :: n + integer(int8) :: res(n, n) + integer :: i + res = 0 + do i = 1, n + res(i, i) = 1 + end do + end function eye + + + #:for k1, t1 in RCI_KINDS_TYPES + function trace_${t1[0]}$${k1}$(A) result(res) + ${t1}$, intent(in) :: A(:,:) + ${t1}$ :: res + integer :: i + res = 0 + do i = 1, minval(shape(A)) + res = res + A(i,i) + end do + end function trace_${t1[0]}$${k1}$ + #:endfor +end module diff --git a/src/stdlib_experimental_linalg.md b/src/stdlib_experimental_linalg.md new file mode 100644 index 000000000..c59e5020d --- /dev/null +++ b/src/stdlib_experimental_linalg.md @@ -0,0 +1,156 @@ +# Linear Algebra + +* [`diag` - Create a diagonal array or extract the diagonal elements of an array](#diag---create-a-diagonal-array-or-extract-the-diagonal-elements-of-an-array) +* [`eye` - Construct the identity matrix](#eye---construct-the-identity-matrix) +* [`trace` - Trace of a matrix](#trace---trace-of-a-matrix) + +## `diag` - Create a diagonal array or extract the diagonal elements of an array + +### Description + +Create a diagonal array or extract the diagonal elements of an array + +### Syntax + +`d = diag(a [, k])` + +### Arguments + +`a`: Shall be a rank-1 or or rank-2 array. If `a` is a rank-1 array (i.e. a vector) then `diag` returns a rank-2 array with the elements of `a` on the diagonal. If `a` is a rank-2 array (i.e. a matrix) then `diag` returns a rank-1 array of the diagonal elements. + +`k` (optional): Shall be a scalar of type `integer` and specifies the diagonal. The default `k = 0` represents the main diagonal, `k > 0` are diagonals above the main diagonal, `k < 0` are diagonals below the main diagonal. + +### Return value + +Returns a diagonal array or a vector with the extracted diagonal elements. + +### Example + +```fortran +program demo_diag1 + use stdlib_experimental_linalg, only: diag + implicit none + real, allocatable :: A(:,:) + integer :: i + A = diag([(1,i=1,10)]) ! creates a 10 by 10 identity matrix +end program demo_diag1 +``` + +```fortran +program demo_diag2 + use stdlib_experimental_linalg, only: diag + implicit none + real :: v(:) + real, allocatable :: A(:,:) + integer :: i + v = [1,2,3,4,5] + A = diag(v) ! creates a 5 by 5 matrix with elements of v on the diagonal +end program demo_diag2 +``` + +```fortran +program demo_diag3 + use stdlib_experimental_linalg, only: diag + implicit none + integer, parameter :: n = 10 + real :: c(n), ul(n-1) + real :: A(n,n) + integer :: i + c = 2 + ul = -1 + A = diag(ul,-1) + diag(c) + diag(ul,1) ! Gil Strang's favorite matrix +end program demo_diag3 +``` + +```fortran +program demo_diag4 + use stdlib_experimental_linalg, only: diag + implicit none + integer, parameter :: n = 12 + real :: A(n,n) + real :: v(n) + integer :: i + call random_number(A) + v = diag(A) ! v contains diagonal elements of A +end program demo_diag4 +``` + +```fortran +program demo_diag5 + use stdlib_experimental_linalg, only: diag + implicit none + integer, parameter :: n = 3 + real :: A(n,n) + real, allocatable :: v(:) + integer :: i + A = reshape([1,2,3,4,5,6,7,8,9],[n,n]) + v = diag(A,-1) ! v is [2,6] + v = diag(A,1) ! v is [4,8] +end program demo_diag5 +``` + +## `eye` - Construct the identity matrix + +### Description + +Construct the identity matrix + +## Syntax + +`I = eye(n)` + +### Arguments + +`n`: Shall be a scalar of default type `integer`. + +### Return value + +Returns the identity matrix, i.e. a square matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`. + +### Example + +```fortran +program demo_eye1 + use stdlib_experimental_linalg, only: eye + implicit none + real :: a(3,3) + A = eye(3) +end program demo_eye1 +``` + +```fortran +program demo_eye2 + use stdlib_experimental_linalg, only: eye, diag + implicit none + print *, all(eye(4) == diag([1,1,1,1])) ! prints .true. +end program demo_eye2 +``` + +## `trace` - Trace of a matrix + +### Description + +Trace of a matrix (rank-2 array) + +### Syntax + +`result = trace(A)` + +### Arguments + +`A`: Shall be a rank-2 array. If `A` is not square, then `trace(A)` will return the sum of diagonal values from the square sub-section of `A`. + +### Return value + +Returns the trace of the matrix, i.e. the sum of diagonal elements. + +### Example +```fortran +program demo_trace + use stdlib_experimental_linalg, only: trace + implicit none + real :: A(3,3) + A = reshape([1,2,3,4,5,6,7,8,9],[3,3]) + print *, trace(A) ! 1 + 5 + 9 +end program demo_trace +``` diff --git a/src/stdlib_experimental_linalg_diag.fypp b/src/stdlib_experimental_linalg_diag.fypp new file mode 100644 index 000000000..d7773fa0e --- /dev/null +++ b/src/stdlib_experimental_linalg_diag.fypp @@ -0,0 +1,80 @@ +#:include "common.fypp" +#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES +submodule (stdlib_experimental_linalg) stdlib_experimental_linalg_diag + + implicit none + +contains + + #:for k1, t1 in RCI_KINDS_TYPES + function diag_${t1[0]}$${k1}$(v) result(res) + ${t1}$, intent(in) :: v(:) + ${t1}$ :: res(size(v),size(v)) + integer :: i + res = 0 + do i = 1, size(v) + res(i,i) = v(i) + end do + end function diag_${t1[0]}$${k1}$ + #:endfor + + + #:for k1, t1 in RCI_KINDS_TYPES + function diag_${t1[0]}$${k1}$_k(v,k) result(res) + ${t1}$, intent(in) :: v(:) + integer, intent(in) :: k + ${t1}$ :: res(size(v)+abs(k),size(v)+abs(k)) + integer :: i, sz + sz = size(v) + res = 0 + if (k > 0) then + do i = 1, sz + res(i,k+i) = v(i) + end do + else if (k < 0) then + do i = 1, sz + res(i+abs(k),i) = v(i) + end do + else + do i = 1, sz + res(i,i) = v(i) + end do + end if + end function diag_${t1[0]}$${k1}$_k + #:endfor + + #:for k1, t1 in RCI_KINDS_TYPES + function diag_${t1[0]}$${k1}$_mat(A) result(res) + ${t1}$, intent(in) :: A(:,:) + ${t1}$ :: res(minval(shape(A))) + integer :: i + do i = 1, minval(shape(A)) + res(i) = A(i,i) + end do + end function diag_${t1[0]}$${k1}$_mat + #:endfor + + #:for k1, t1 in RCI_KINDS_TYPES + function diag_${t1[0]}$${k1}$_mat_k(A,k) result(res) + ${t1}$, intent(in) :: A(:,:) + integer, intent(in) :: k + ${t1}$ :: res(minval(shape(A))-abs(k)) + integer :: i, sz + sz = minval(shape(A))-abs(k) + if (k > 0) then + do i = 1, sz + res(i) = A(i,k+i) + end do + else if (k < 0) then + do i = 1, sz + res(i) = A(i+abs(k),i) + end do + else + do i = 1, sz + res(i) = A(i,i) + end do + end if + end function diag_${t1[0]}$${k1}$_mat_k + #:endfor + +end submodule diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index f3b7d434c..593d261b6 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -8,6 +8,7 @@ endmacro(ADDTEST) add_subdirectory(ascii) add_subdirectory(io) +add_subdirectory(linalg) add_subdirectory(optval) add_subdirectory(stats) add_subdirectory(system) diff --git a/src/tests/linalg/CMakeLists.txt b/src/tests/linalg/CMakeLists.txt new file mode 100644 index 000000000..f1098405b --- /dev/null +++ b/src/tests/linalg/CMakeLists.txt @@ -0,0 +1,2 @@ +ADDTEST(linalg) + diff --git a/src/tests/linalg/test_linalg.f90 b/src/tests/linalg/test_linalg.f90 new file mode 100644 index 000000000..fa0c79a6e --- /dev/null +++ b/src/tests/linalg/test_linalg.f90 @@ -0,0 +1,445 @@ +program test_linalg + + use stdlib_experimental_error, only: check + use stdlib_experimental_kinds, only: sp, dp, qp, int8, int16, int32, int64 + use stdlib_experimental_linalg, only: diag, eye, trace + + implicit none + + real(sp), parameter :: sptol = 1000 * epsilon(1._sp) + real(dp), parameter :: dptol = 1000 * epsilon(1._dp) + real(qp), parameter :: qptol = 1000 * epsilon(1._qp) + + logical :: warn + + ! whether calls to check issue a warning + ! or stop execution + warn = .false. + + ! + ! eye + ! + call test_eye + + ! + ! diag + ! + call test_diag_rsp + call test_diag_rsp_k + call test_diag_rdp + call test_diag_rqp + + call test_diag_csp + call test_diag_cdp + call test_diag_cqp + + call test_diag_int8 + call test_diag_int16 + call test_diag_int32 + call test_diag_int64 + + ! + ! trace + ! + call test_trace_rsp + call test_trace_rsp_nonsquare + call test_trace_rdp + call test_trace_rdp_nonsquare + call test_trace_rqp + + call test_trace_csp + call test_trace_cdp + call test_trace_cqp + + call test_trace_int8 + call test_trace_int16 + call test_trace_int32 + call test_trace_int64 + + +contains + + subroutine test_eye + real(sp), allocatable :: rye(:,:) + complex(sp) :: cye(7,7) + integer :: i + write(*,*) "test_eye" + + call check(all(eye(5) == diag([(1,i=1,5)])), & + msg="all(eye(5) == diag([(1,i=1,5)] failed.",warn=warn) + + rye = eye(6) + call check(sum(rye - diag([(1.0_sp,i=1,6)])) < sptol, & + msg="sum(rye - diag([(1.0_sp,i=1,6)])) < sptol failed.",warn=warn) + + cye = eye(7) + call check(abs(trace(cye) - complex(7.0_sp,0.0_sp)) < sptol, & + msg="abs(trace(cye) - complex(7.0_sp,0.0_sp)) < sptol failed.",warn=warn) + end subroutine + + subroutine test_diag_rsp + integer, parameter :: n = 3 + real(sp) :: v(n), a(n,n), b(n,n) + integer :: i,j + write(*,*) "test_diag_rsp" + v = [(i,i=1,n)] + a = diag(v) + b = reshape([((merge(i,0,i==j), i=1,n), j=1,n)], [n,n]) + call check(all(a == b), & + msg="all(a == b) failed.",warn=warn) + + call check(all(diag(3*a) == 3*v), & + msg="all(diag(3*a) == 3*v) failed.",warn=warn) + end subroutine + + subroutine test_diag_rsp_k + integer, parameter :: n = 4 + real(sp) :: a(n,n), b(n,n) + integer :: i,j + write(*,*) "test_diag_rsp_k" + + a = diag([(1._sp,i=1,n-1)],-1) + + b = reshape([((merge(1,0,i==j+1), i=1,n), j=1,n)], [n,n]) + + call check(all(a == b), & + msg="all(a == b) failed.",warn=warn) + + call check(sum(diag(a,-1)) - (n-1) < sptol, & + msg="sum(diag(a,-1)) - (n-1) < sptol failed.",warn=warn) + + call check(all(a == transpose(diag([(1._sp,i=1,n-1)],1))), & + msg="all(a == transpose(diag([(1._sp,i=1,n-1)],1))) failed",warn=warn) + + call random_number(a) + do i = 1, n + call check(size(diag(a,i)) == n-i, & + msg="size(diag(a,i)) == n-i failed.",warn=warn) + end do + call check(size(diag(a,n+1)) == 0, & + msg="size(diag(a,n+1)) == 0 failed.",warn=warn) + end subroutine + + subroutine test_diag_rdp + integer, parameter :: n = 3 + real(dp) :: v(n), a(n,n), b(n,n) + integer :: i,j + write(*,*) "test_diag_rdp" + v = [(i,i=1,n)] + a = diag(v) + b = reshape([((merge(i,0,i==j), i=1,n), j=1,n)], [n,n]) + call check(all(a == b), & + msg="all(a == b) failed.",warn=warn) + + call check(all(diag(3*a) == 3*v), & + msg="all(diag(3*a) == 3*v) failed.",warn=warn) + end subroutine + + subroutine test_diag_rqp + integer, parameter :: n = 3 + real(qp) :: v(n), a(n,n), b(n,n) + integer :: i,j + write(*,*) "test_diag_rqp" + v = [(i,i=1,n)] + a = diag(v) + b = reshape([((merge(i,0,i==j), i=1,n), j=1,n)], [n,n]) + call check(all(a == b), & + msg="all(a == b) failed.", warn=warn) + + call check(all(diag(3*a) == 3*v), & + msg="all(diag(3*a) == 3*v) failed.", warn=warn) + end subroutine + + subroutine test_diag_csp + integer, parameter :: n = 3 + complex(sp) :: v(n), a(n,n), b(n,n) + complex(sp), parameter :: i_ = complex(0,1) + integer :: i,j + write(*,*) "test_diag_csp" + a = diag([(i,i=1,n)]) + diag([(i_,i=1,n)]) + b = reshape([((merge(i + 1*i_,0*i_,i==j), i=1,n), j=1,n)], [n,n]) + call check(all(a == b), & + msg="all(a == b) failed.",warn=warn) + + call check(all(abs(real(diag(a)) - [(i,i=1,n)]) < sptol), & + msg="all(abs(real(diag(a)) - [(i,i=1,n)]) < sptol)", warn=warn) + call check(all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol), & + msg="all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol)", warn=warn) + end subroutine + + subroutine test_diag_cdp + integer, parameter :: n = 3 + complex(dp) :: v(n), a(n,n), b(n,n) + complex(dp), parameter :: i_ = complex(0,1) + integer :: i,j + write(*,*) "test_diag_cdp" + a = diag([i_],-2) + diag([i_],2) + call check(a(3,1) == i_ .and. a(1,3) == i_, & + msg="a(3,1) == i_ .and. a(1,3) == i_ failed.",warn=warn) + end subroutine + + subroutine test_diag_cqp + integer, parameter :: n = 3 + complex(qp) :: v(n), a(n,n), b(n,n) + complex(qp), parameter :: i_ = complex(0,1) + integer :: i,j + write(*,*) "test_diag_cqp" + a = diag([i_,i_],-1) + diag([i_,i_],1) + call check(all(diag(a,-1) == i_) .and. all(diag(a,1) == i_), & + msg="all(diag(a,-1) == i_) .and. all(diag(a,1) == i_) failed.",warn=warn) + end subroutine + + subroutine test_diag_int8 + integer, parameter :: n = 3 + integer(int8), allocatable :: a(:,:) + integer :: i + logical, allocatable :: mask(:,:) + write(*,*) "test_diag_int8" + a = reshape([(i,i=1,n**2)],[n,n]) + mask = merge(.true.,.false.,eye(n) == 1) + call check(all(diag(a) == pack(a,mask)), & + msg="all(diag(a) == pack(a,mask)) failed.", warn=warn) + call check(all(diag(diag(a)) == merge(a,0_int8,mask)), & + msg="all(diag(diag(a)) == merge(a,0_int8,mask)) failed.", warn=warn) + end subroutine + subroutine test_diag_int16 + integer, parameter :: n = 4 + integer(int16), allocatable :: a(:,:) + integer :: i + logical, allocatable :: mask(:,:) + write(*,*) "test_diag_int16" + a = reshape([(i,i=1,n**2)],[n,n]) + mask = merge(.true.,.false.,eye(n) == 1) + call check(all(diag(a) == pack(a,mask)), & + msg="all(diag(a) == pack(a,mask))", warn=warn) + call check(all(diag(diag(a)) == merge(a,0_int16,mask)), & + msg="all(diag(diag(a)) == merge(a,0_int16,mask)) failed.", warn=warn) + end subroutine + subroutine test_diag_int32 + integer, parameter :: n = 3 + integer(int32) :: a(n,n) + logical :: mask(n,n) + integer :: i, j + write(*,*) "test_diag_int32" + mask = reshape([((merge(.true.,.false.,i==j+1), i=1,n), j=1,n)], [n,n]) + a = 0 + a = unpack([1_int32,1_int32],mask,a) + call check(all(diag([1,1],-1) == a), & + msg="all(diag([1,1],-1) == a) failed.", warn=warn) + call check(all(diag([1,1],1) == transpose(a)), & + msg="all(diag([1,1],1) == transpose(a)) failed.", warn=warn) + end subroutine + subroutine test_diag_int64 + integer, parameter :: n = 4 + integer(int64) :: a(n,n), c(0:2*n-1) + logical :: mask(n,n) + integer :: i, j + + write(*,*) "test_diag_int64" + + mask = reshape([((merge(.true.,.false.,i+1==j), i=1,n), j=1,n)], [n,n]) + a = 0 + a = unpack([1_int64,1_int64,1_int64],mask,a) + + call check(all(diag([1,1,1],1) == a), & + msg="all(diag([1,1,1],1) == a) failed.", warn=warn) + call check(all(diag([1,1,1],-1) == transpose(a)), & + msg="all(diag([1,1,1],-1) == transpose(a)) failed.", warn=warn) + + + ! Fill array c with Catalan numbers + do i = 0, 2*n-1 + c(i) = catalan_number(i) + end do + ! Symmetric Hankel matrix filled with Catalan numbers (det(H) = 1) + do i = 1, n + do j = 1, n + a(i,j) = c(i-1 + (j-1)) + end do + end do + call check(all(diag(a,-2) == diag(a,2)), & + msg="all(diag(a,-2) == diag(a,2))", warn=warn) + end subroutine + + + + + subroutine test_trace_rsp + integer, parameter :: n = 5 + real(sp) :: a(n,n) + integer :: i + write(*,*) "test_trace_rsp" + a = reshape([(i,i=1,n**2)],[n,n]) + call check(abs(trace(a) - sum(diag(a))) < sptol, & + msg="abs(trace(a) - sum(diag(a))) < sptol failed.",warn=warn) + end subroutine + + subroutine test_trace_rsp_nonsquare + integer, parameter :: n = 4 + real(sp) :: a(n,n+1), ans + integer :: i + write(*,*) "test_trace_rsp_nonsquare" + + ! 1 5 9 13 17 + ! 2 6 10 14 18 + ! 3 7 11 15 19 + ! 4 8 12 16 20 + a = reshape([(i,i=1,n*(n+1))],[n,n+1]) + ans = sum([1._sp,6._sp,11._sp,16._sp]) + + call check(abs(trace(a) - ans) < sptol, & + msg="abs(trace(a) - ans) < sptol failed.",warn=warn) + end subroutine + + subroutine test_trace_rdp + integer, parameter :: n = 4 + real(dp) :: a(n,n) + integer :: i + write(*,*) "test_trace_rdp" + a = reshape([(i,i=1,n**2)],[n,n]) + call check(abs(trace(a) - sum(diag(a))) < dptol, & + msg="abs(trace(a) - sum(diag(a))) < dptol failed.",warn=warn) + end subroutine + + subroutine test_trace_rdp_nonsquare + integer, parameter :: n = 4 + real(dp) :: a(n,n-1), ans + integer :: i + write(*,*) "test_trace_rdp_nonsquare" + + ! 1 25 81 + ! 4 36 100 + ! 9 49 121 + ! 16 64 144 + a = reshape([(i**2,i=1,n*(n-1))],[n,n-1]) + ans = sum([1._dp,36._dp,121._dp]) + + call check(abs(trace(a) - ans) < dptol, & + msg="abs(trace(a) - ans) < dptol failed.",warn=warn) + end subroutine + + subroutine test_trace_rqp + integer, parameter :: n = 3 + real(qp) :: a(n,n) + integer :: i + write(*,*) "test_trace_rqp" + a = reshape([(i,i=1,n**2)],[n,n]) + call check(abs(trace(a) - sum(diag(a))) < qptol, & + msg="abs(trace(a) - sum(diag(a))) < qptol failed.",warn=warn) + end subroutine + + + subroutine test_trace_csp + integer, parameter :: n = 5 + real(sp) :: re(n,n), im(n,n) + complex(sp) :: a(n,n), b(n,n) + complex(sp), parameter :: i_ = complex(0,1) + write(*,*) "test_trace_csp" + + call random_number(re) + call random_number(im) + a = re + im*i_ + + call random_number(re) + call random_number(im) + b = re + im*i_ + + ! tr(A + B) = tr(A) + tr(B) + call check(abs(trace(a+b) - (trace(a) + trace(b))) < sptol, & + msg="abs(trace(a+b) - (trace(a) + trace(b))) < sptol failed.",warn=warn) + end subroutine + + subroutine test_trace_cdp + integer, parameter :: n = 3 + complex(dp) :: a(n,n), ans + complex(dp), parameter :: i_ = complex(0,1) + integer :: j + write(*,*) "test_trace_cdp" + + a = reshape([(j + (n**2 - (j-1))*i_,j=1,n**2)],[n,n]) + ans = complex(15,15) !(1 + 5 + 9) + (9 + 5 + 1)i + + call check(abs(trace(a) - ans) < dptol, & + msg="abs(trace(a) - ans) < dptol failed.",warn=warn) + end subroutine + + subroutine test_trace_cqp + integer, parameter :: n = 3 + complex(qp) :: a(n,n) + complex(qp), parameter :: i_ = complex(0,1) + write(*,*) "test_trace_cqp" + a = 3*eye(n) + 4*eye(n)*i_ ! pythagorean triple + call check(abs(trace(a)) - 3*5.0_qp < qptol, & + msg="abs(trace(a)) - 3*5.0_qp < qptol failed.",warn=warn) + end subroutine + + + subroutine test_trace_int8 + integer, parameter :: n = 3 + integer(int8) :: a(n,n) + integer :: i + write(*,*) "test_trace_int8" + a = reshape([(i**2,i=1,n**2)],[n,n]) + call check(trace(a) == (1 + 25 + 81), & + msg="trace(a) == (1 + 25 + 81) failed.",warn=warn) + end subroutine + + subroutine test_trace_int16 + integer, parameter :: n = 3 + integer(int16) :: a(n,n) + integer :: i + write(*,*) "test_trace_int16" + a = reshape([(i**3,i=1,n**2)],[n,n]) + call check(trace(a) == (1 + 125 + 729), & + msg="trace(a) == (1 + 125 + 729) failed.",warn=warn) + end subroutine + + subroutine test_trace_int32 + integer, parameter :: n = 3 + integer(int32) :: a(n,n) + integer :: i + write(*,*) "test_trace_int32" + a = reshape([(i**4,i=1,n**2)],[n,n]) + call check(trace(a) == (1 + 625 + 6561), & + msg="trace(a) == (1 + 625 + 6561) failed.",warn=warn) + end subroutine + + subroutine test_trace_int64 + integer, parameter :: n = 5 + integer, parameter :: nd = 2*n-1 ! number of diagonals + integer :: i, j + integer(int64) :: c(0:nd), H(n,n) + write(*,*) "test_trace_int64" + + ! Fill array with Catalan numbers + do i = 0, nd + c(i) = catalan_number(i) + end do + + ! Symmetric Hankel matrix filled with Catalan numbers (det(H) = 1) + do i = 1, n + do j = 1, n + H(i,j) = c(i-1 + (j-1)) + end do + end do + + call check(trace(h) == sum(c(0:nd:2)), & + msg="trace(h) == sum(c(0:nd:2)) failed.",warn=warn) + + end subroutine + + pure recursive function catalan_number(n) result(value) + integer, intent(in) :: n + integer :: value + integer :: i + if (n <= 1) then + value = 1 + else + value = 0 + do i = 0, n-1 + value = value + catalan_number(i)*catalan_number(n-i-1) + end do + end if + end function + +end program \ No newline at end of file