From 511c0dbbe70557d002f568b3c382ab7dec1659f6 Mon Sep 17 00:00:00 2001 From: St Maxwell Date: Tue, 22 Nov 2022 23:21:44 +0800 Subject: [PATCH 1/4] cross product of two vectors --- doc/specs/stdlib_linalg.md | 30 +++++ example/linalg/example_cross_product.f90 | 9 ++ src/stdlib_linalg.fypp | 16 +++ src/stdlib_linalg_cross_product.fypp | 21 +++ test/linalg/test_linalg.fypp | 159 ++++++++++++++++++++++- 5 files changed, 234 insertions(+), 1 deletion(-) create mode 100644 example/linalg/example_cross_product.f90 create mode 100644 src/stdlib_linalg_cross_product.fypp diff --git a/doc/specs/stdlib_linalg.md b/doc/specs/stdlib_linalg.md index f7569e0c6..29eb9878d 100644 --- a/doc/specs/stdlib_linalg.md +++ b/doc/specs/stdlib_linalg.md @@ -160,6 +160,36 @@ Returns a rank-2 array equal to `u v^T` (where `u, v` are considered column vect {!example/linalg/example_outer_product.f90!} ``` +## `cross_product` - Computes the cross product of two vectors + +### Status + +Experimental + +### Description + +Computes the cross product of two vectors + +### Syntax + +`c = [[stdlib_linalg(module):cross_product(interface)]](a, b)` + +### Arguments + +`a`: Shall be a rank-1 and size-3 array + +`b`: Shall be a rank-1 and size-3 array + +### Return value + +Returns a rank-1 and size-3 array which is perpendicular to both `a` and `b`. + +### Example + +```fortran +{!example/linalg/example_cross_product.f90!} +``` + ## `is_square` - Checks if a matrix is square ### Status diff --git a/example/linalg/example_cross_product.f90 b/example/linalg/example_cross_product.f90 new file mode 100644 index 000000000..e546647f4 --- /dev/null +++ b/example/linalg/example_cross_product.f90 @@ -0,0 +1,9 @@ +program demo_cross_product + use stdlib_linalg, only: cross_product + implicit none + real :: a(3), b(3), c(3) + a = [1., 0., 0.] + b = [0., 1., 0.] + c = cross_product(a, b) + !c = [0., 0., 1.] +end program demo_cross_product diff --git a/src/stdlib_linalg.fypp b/src/stdlib_linalg.fypp index bc1017f0a..cfa43d3d9 100644 --- a/src/stdlib_linalg.fypp +++ b/src/stdlib_linalg.fypp @@ -14,6 +14,7 @@ module stdlib_linalg public :: eye public :: trace public :: outer_product + public :: cross_product public :: is_square public :: is_diagonal public :: is_symmetric @@ -93,6 +94,21 @@ module stdlib_linalg end interface outer_product + ! Cross product (of two vectors) + interface cross_product + !! version: experimental + !! + !! Computes the cross product of two vectors, returning a rank-1 and size-3 array + !! ([Specification](../page/specs/stdlib_linalg.html#cross_product-computes-the-cross-product-of-two-3-d-vectors)) + #:for k1, t1 in RCI_KINDS_TYPES + pure module function cross_product_${t1[0]}$${k1}$(a, b) result(res) + ${t1}$, intent(in) :: a(3), b(3) + ${t1}$ :: res(3) + end function cross_product_${t1[0]}$${k1}$ + #:endfor + end interface cross_product + + ! Check for squareness interface is_square !! version: experimental diff --git a/src/stdlib_linalg_cross_product.fypp b/src/stdlib_linalg_cross_product.fypp new file mode 100644 index 000000000..bc0afa0a0 --- /dev/null +++ b/src/stdlib_linalg_cross_product.fypp @@ -0,0 +1,21 @@ +#:include "common.fypp" +#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES +submodule (stdlib_linalg) stdlib_linalg_cross_product + use stdlib_error, only: error_stop + implicit none + +contains + + #:for k1, t1 in RCI_KINDS_TYPES + pure module function cross_product_${t1[0]}$${k1}$(a, b) result(res) + ${t1}$, intent(in) :: a(3), b(3) + ${t1}$ :: res(3) + + res(1) = a(2) * b(3) - a(3) * b(2) + res(2) = a(3) * b(1) - a(1) * b(3) + res(3) = a(1) * b(2) - a(2) * b(1) + + end function cross_product_${t1[0]}$${k1}$ + #:endfor + +end submodule diff --git a/test/linalg/test_linalg.fypp b/test/linalg/test_linalg.fypp index f74cbff6b..f3bbcdf2c 100644 --- a/test/linalg/test_linalg.fypp +++ b/test/linalg/test_linalg.fypp @@ -3,7 +3,7 @@ module test_linalg use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64 - use stdlib_linalg, only: diag, eye, trace, outer_product + use stdlib_linalg, only: diag, eye, trace, outer_product,cross_product implicit none @@ -702,6 +702,163 @@ contains "all(abs(diff) == 0) failed.") end subroutine test_outer_product_int64 + subroutine test_cross_product_int8(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + integer(int8) :: u(n), v(n), expected(n), diff(n) + + u = [1,0,0] + v = [0,1,0] + expected = [0,0,1] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) == 0), & + "all(abs(diff) == 0) failed.") + end subroutine test_cross_product_int8 + + subroutine test_cross_product_int16(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + integer(int16) :: u(n), v(n), expected(n), diff(n) + + u = [1,0,0] + v = [0,1,0] + expected = [0,0,1] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) == 0), & + "all(abs(diff) == 0) failed.") + end subroutine test_cross_product_int16 + + subroutine test_cross_product_int32(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + integer(int32) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_int32" + u = [1,0,0] + v = [0,1,0] + expected = [0,0,1] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) == 0), & + "all(abs(diff) == 0) failed.") + end subroutine test_cross_product_int32 + + subroutine test_cross_product_int64(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + integer(int64) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_int64" + u = [1,0,0] + v = [0,1,0] + expected = [0,0,1] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) == 0), & + "all(abs(diff) == 0) failed.") + end subroutine test_cross_product_int64 + + subroutine test_cross_product_rsp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + real(sp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_rsp" + u = [1.1_sp,2.5_sp,2.4_sp] + v = [0.5_sp,1.5_sp,2.5_sp] + expected = [2.65_sp,-1.55_sp,0.4_sp] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < sptol), & + "all(abs(diff) < sptol) failed.") + end subroutine test_cross_product_rsp + + subroutine test_cross_product_rdp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + real(dp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_rdp" + u = [1.1_dp,2.5_dp,2.4_dp] + v = [0.5_dp,1.5_dp,2.5_dp] + expected = [2.65_dp,-1.55_dp,0.4_dp] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < dptol), & + "all(abs(diff) < dptol) failed.") + end subroutine test_cross_product_rdp + + subroutine test_cross_product_rqp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + +#:if WITH_QP + integer, parameter :: n = 3 + real(qp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_rqp" + u = [1.1_qp,2.5_qp,2.4_qp] + v = [0.5_qp,1.5_qp,2.5_qp] + expected = [2.65_qp,-1.55_qp,0.4_qp] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < qptol), & + "all(abs(diff) < qptol) failed.") +#:else + call skip_test(error, "Quadruple precision is not enabled") +#:endif + end subroutine test_cross_product_rqp + + subroutine test_cross_product_csp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + complex(sp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_csp" + u = [cmplx(0,1,sp),cmplx(1,0,sp),cmplx(0,0,sp)] + v = [cmplx(1,1,sp),cmplx(0,0,sp),cmplx(1,0,sp)] + expected = [cmplx(1,0,sp),cmplx(0,-1,sp),cmplx(-1,-1,sp)] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < sptol), & + "all(abs(diff) < sptol) failed.") + end subroutine test_cross_product_csp + + subroutine test_cross_product_cdp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + + integer, parameter :: n = 3 + complex(dp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_cdp" + u = [cmplx(0,1,dp),cmplx(1,0,dp),cmplx(0,0,dp)] + v = [cmplx(1,1,dp),cmplx(0,0,dp),cmplx(1,0,dp)] + expected = [cmplx(1,0,dp),cmplx(0,-1,dp),cmplx(-1,-1,dp)] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < dptol), & + "all(abs(diff) < dptol) failed.") + end subroutine test_cross_product_cdp + + subroutine test_cross_product_cqp(error) + !> Error handling + type(error_type), allocatable, intent(out) :: error + +#:if WITH_QP + integer, parameter :: n = 3 + complex(qp) :: u(n), v(n), expected(n), diff(n) + write(*,*) "test_cross_product_cqp" + u = [cmplx(0,1,qp),cmplx(1,0,qp),cmplx(0,0,qp)] + v = [cmplx(1,1,qp),cmplx(0,0,qp),cmplx(1,0,qp)] + expected = [cmplx(1,0,qp),cmplx(0,-1,qp),cmplx(-1,-1,qp)] + diff = expected - cross_product(u,v) + call check(error, all(abs(diff) < qptol), & + "all(abs(diff) < qptol) failed.") +#:else + call skip_test(error, "Quadruple precision is not enabled") +#:endif + end subroutine test_cross_product_cqp pure recursive function catalan_number(n) result(value) integer, intent(in) :: n From e3e111c24df6cae8b3e8a6c820bc3149c4fd28fd Mon Sep 17 00:00:00 2001 From: St Maxwell Date: Tue, 22 Nov 2022 23:34:24 +0800 Subject: [PATCH 2/4] fix test_linalg --- test/linalg/test_linalg.fypp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/test/linalg/test_linalg.fypp b/test/linalg/test_linalg.fypp index f3bbcdf2c..2fd99ce7e 100644 --- a/test/linalg/test_linalg.fypp +++ b/test/linalg/test_linalg.fypp @@ -57,7 +57,17 @@ contains new_unittest("outer_product_int8", test_outer_product_int8), & new_unittest("outer_product_int16", test_outer_product_int16), & new_unittest("outer_product_int32", test_outer_product_int32), & - new_unittest("outer_product_int64", test_outer_product_int64) & + new_unittest("outer_product_int64", test_outer_product_int64), & + new_unittest("cross_product_rsp", test_cross_product_rsp), & + new_unittest("cross_product_rdp", test_cross_product_rdp), & + new_unittest("cross_product_rqp", test_cross_product_rqp), & + new_unittest("cross_product_csp", test_cross_product_csp), & + new_unittest("cross_product_cdp", test_cross_product_cdp), & + new_unittest("cross_product_cqp", test_cross_product_cqp), & + new_unittest("cross_product_int8", test_cross_product_int8), & + new_unittest("cross_product_int16", test_cross_product_int16), & + new_unittest("cross_product_int32", test_cross_product_int32), & + new_unittest("cross_product_int64", test_cross_product_int64) & ] end subroutine collect_linalg From 07e3dfa30a1177787263aae8ac664ac94898d918 Mon Sep 17 00:00:00 2001 From: St Maxwell Date: Tue, 22 Nov 2022 23:42:01 +0800 Subject: [PATCH 3/4] fix cmakelist --- src/CMakeLists.txt | 1 + src/stdlib_linalg_cross_product.fypp | 2 +- test/linalg/test_linalg.fypp | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6f1fd0a18..8f512af56 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ set(fppFiles stdlib_linalg.fypp stdlib_linalg_diag.fypp stdlib_linalg_outer_product.fypp + stdlib_linalg_cross_product.fypp stdlib_optval.fypp stdlib_selection.fypp stdlib_sorting.fypp diff --git a/src/stdlib_linalg_cross_product.fypp b/src/stdlib_linalg_cross_product.fypp index bc0afa0a0..46d9e736a 100644 --- a/src/stdlib_linalg_cross_product.fypp +++ b/src/stdlib_linalg_cross_product.fypp @@ -1,7 +1,7 @@ #:include "common.fypp" #:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES submodule (stdlib_linalg) stdlib_linalg_cross_product - use stdlib_error, only: error_stop + implicit none contains diff --git a/test/linalg/test_linalg.fypp b/test/linalg/test_linalg.fypp index 2fd99ce7e..de42004ed 100644 --- a/test/linalg/test_linalg.fypp +++ b/test/linalg/test_linalg.fypp @@ -3,7 +3,7 @@ module test_linalg use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64 - use stdlib_linalg, only: diag, eye, trace, outer_product,cross_product + use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product implicit none From cb31a95c268c90111daa561b57e1054efbf8190c Mon Sep 17 00:00:00 2001 From: St Maxwell Date: Tue, 6 Dec 2022 11:44:52 +0800 Subject: [PATCH 4/4] more explicit test fail message --- test/linalg/test_linalg.fypp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/linalg/test_linalg.fypp b/test/linalg/test_linalg.fypp index de42004ed..2ffd2d7de 100644 --- a/test/linalg/test_linalg.fypp +++ b/test/linalg/test_linalg.fypp @@ -724,7 +724,7 @@ contains expected = [0,0,1] diff = expected - cross_product(u,v) call check(error, all(abs(diff) == 0), & - "all(abs(diff) == 0) failed.") + "cross_product(u,v) == expected failed.") end subroutine test_cross_product_int8 subroutine test_cross_product_int16(error) @@ -739,7 +739,7 @@ contains expected = [0,0,1] diff = expected - cross_product(u,v) call check(error, all(abs(diff) == 0), & - "all(abs(diff) == 0) failed.") + "cross_product(u,v) == expected failed.") end subroutine test_cross_product_int16 subroutine test_cross_product_int32(error) @@ -754,7 +754,7 @@ contains expected = [0,0,1] diff = expected - cross_product(u,v) call check(error, all(abs(diff) == 0), & - "all(abs(diff) == 0) failed.") + "cross_product(u,v) == expected failed.") end subroutine test_cross_product_int32 subroutine test_cross_product_int64(error) @@ -769,7 +769,7 @@ contains expected = [0,0,1] diff = expected - cross_product(u,v) call check(error, all(abs(diff) == 0), & - "all(abs(diff) == 0) failed.") + "cross_product(u,v) == expected failed.") end subroutine test_cross_product_int64 subroutine test_cross_product_rsp(error) @@ -784,7 +784,7 @@ contains expected = [2.65_sp,-1.55_sp,0.4_sp] diff = expected - cross_product(u,v) call check(error, all(abs(diff) < sptol), & - "all(abs(diff) < sptol) failed.") + "all(abs(cross_product(u,v)-expected)) < sptol failed.") end subroutine test_cross_product_rsp subroutine test_cross_product_rdp(error) @@ -799,7 +799,7 @@ contains expected = [2.65_dp,-1.55_dp,0.4_dp] diff = expected - cross_product(u,v) call check(error, all(abs(diff) < dptol), & - "all(abs(diff) < dptol) failed.") + "all(abs(cross_product(u,v)-expected)) < dptol failed.") end subroutine test_cross_product_rdp subroutine test_cross_product_rqp(error) @@ -815,7 +815,7 @@ contains expected = [2.65_qp,-1.55_qp,0.4_qp] diff = expected - cross_product(u,v) call check(error, all(abs(diff) < qptol), & - "all(abs(diff) < qptol) failed.") + "all(abs(cross_product(u,v)-expected)) < qptol failed.") #:else call skip_test(error, "Quadruple precision is not enabled") #:endif @@ -833,7 +833,7 @@ contains expected = [cmplx(1,0,sp),cmplx(0,-1,sp),cmplx(-1,-1,sp)] diff = expected - cross_product(u,v) call check(error, all(abs(diff) < sptol), & - "all(abs(diff) < sptol) failed.") + "all(abs(cross_product(u,v)-expected)) < sptol failed.") end subroutine test_cross_product_csp subroutine test_cross_product_cdp(error) @@ -848,7 +848,7 @@ contains expected = [cmplx(1,0,dp),cmplx(0,-1,dp),cmplx(-1,-1,dp)] diff = expected - cross_product(u,v) call check(error, all(abs(diff) < dptol), & - "all(abs(diff) < dptol) failed.") + "all(abs(cross_product(u,v)-expected)) < dptol failed.") end subroutine test_cross_product_cdp subroutine test_cross_product_cqp(error) @@ -864,7 +864,7 @@ contains expected = [cmplx(1,0,qp),cmplx(0,-1,qp),cmplx(-1,-1,qp)] diff = expected - cross_product(u,v) call check(error, all(abs(diff) < qptol), & - "all(abs(diff) < qptol) failed.") + "all(abs(cross_product(u,v)-expected)) < qptol failed.") #:else call skip_test(error, "Quadruple precision is not enabled") #:endif