Skip to content

cross product of two vectors #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,36 @@ Returns a rank-2 array equal to `u v^T` (where `u, v` are considered column vect
{!example/linalg/example_outer_product.f90!}
```

## `cross_product` - Computes the cross product of two vectors

### Status

Experimental

### Description

Computes the cross product of two vectors

### Syntax

`c = [[stdlib_linalg(module):cross_product(interface)]](a, b)`

### Arguments

`a`: Shall be a rank-1 and size-3 array

`b`: Shall be a rank-1 and size-3 array

### Return value

Returns a rank-1 and size-3 array which is perpendicular to both `a` and `b`.

### Example

```fortran
{!example/linalg/example_cross_product.f90!}
```

## `is_square` - Checks if a matrix is square

### Status
Expand Down
9 changes: 9 additions & 0 deletions example/linalg/example_cross_product.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
program demo_cross_product
use stdlib_linalg, only: cross_product
implicit none
real :: a(3), b(3), c(3)
a = [1., 0., 0.]
b = [0., 1., 0.]
c = cross_product(a, b)
!c = [0., 0., 1.]
end program demo_cross_product
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(fppFiles
stdlib_linalg.fypp
stdlib_linalg_diag.fypp
stdlib_linalg_outer_product.fypp
stdlib_linalg_cross_product.fypp
stdlib_optval.fypp
stdlib_selection.fypp
stdlib_sorting.fypp
Expand Down
16 changes: 16 additions & 0 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module stdlib_linalg
public :: eye
public :: trace
public :: outer_product
public :: cross_product
public :: is_square
public :: is_diagonal
public :: is_symmetric
Expand Down Expand Up @@ -93,6 +94,21 @@ module stdlib_linalg
end interface outer_product


! Cross product (of two vectors)
interface cross_product
!! version: experimental
!!
!! Computes the cross product of two vectors, returning a rank-1 and size-3 array
!! ([Specification](../page/specs/stdlib_linalg.html#cross_product-computes-the-cross-product-of-two-3-d-vectors))
#:for k1, t1 in RCI_KINDS_TYPES
pure module function cross_product_${t1[0]}$${k1}$(a, b) result(res)
${t1}$, intent(in) :: a(3), b(3)
${t1}$ :: res(3)
end function cross_product_${t1[0]}$${k1}$
#:endfor
end interface cross_product


! Check for squareness
interface is_square
!! version: experimental
Expand Down
21 changes: 21 additions & 0 deletions src/stdlib_linalg_cross_product.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#:include "common.fypp"
#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES
submodule (stdlib_linalg) stdlib_linalg_cross_product

implicit none

contains

#:for k1, t1 in RCI_KINDS_TYPES
pure module function cross_product_${t1[0]}$${k1}$(a, b) result(res)
${t1}$, intent(in) :: a(3), b(3)
${t1}$ :: res(3)

res(1) = a(2) * b(3) - a(3) * b(2)
res(2) = a(3) * b(1) - a(1) * b(3)
res(3) = a(1) * b(2) - a(2) * b(1)

end function cross_product_${t1[0]}$${k1}$
#:endfor

end submodule
171 changes: 169 additions & 2 deletions test/linalg/test_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module test_linalg
use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test
use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64
use stdlib_linalg, only: diag, eye, trace, outer_product
use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product

implicit none

Expand Down Expand Up @@ -57,7 +57,17 @@ contains
new_unittest("outer_product_int8", test_outer_product_int8), &
new_unittest("outer_product_int16", test_outer_product_int16), &
new_unittest("outer_product_int32", test_outer_product_int32), &
new_unittest("outer_product_int64", test_outer_product_int64) &
new_unittest("outer_product_int64", test_outer_product_int64), &
new_unittest("cross_product_rsp", test_cross_product_rsp), &
new_unittest("cross_product_rdp", test_cross_product_rdp), &
new_unittest("cross_product_rqp", test_cross_product_rqp), &
new_unittest("cross_product_csp", test_cross_product_csp), &
new_unittest("cross_product_cdp", test_cross_product_cdp), &
new_unittest("cross_product_cqp", test_cross_product_cqp), &
new_unittest("cross_product_int8", test_cross_product_int8), &
new_unittest("cross_product_int16", test_cross_product_int16), &
new_unittest("cross_product_int32", test_cross_product_int32), &
new_unittest("cross_product_int64", test_cross_product_int64) &
]

end subroutine collect_linalg
Expand Down Expand Up @@ -702,6 +712,163 @@ contains
"all(abs(diff) == 0) failed.")
end subroutine test_outer_product_int64

subroutine test_cross_product_int8(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

integer, parameter :: n = 3
integer(int8) :: u(n), v(n), expected(n), diff(n)

u = [1,0,0]
v = [0,1,0]
expected = [0,0,1]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) == 0), &
"cross_product(u,v) == expected failed.")
end subroutine test_cross_product_int8

subroutine test_cross_product_int16(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

integer, parameter :: n = 3
integer(int16) :: u(n), v(n), expected(n), diff(n)

u = [1,0,0]
v = [0,1,0]
expected = [0,0,1]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) == 0), &
"cross_product(u,v) == expected failed.")
end subroutine test_cross_product_int16

subroutine test_cross_product_int32(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

integer, parameter :: n = 3
integer(int32) :: u(n), v(n), expected(n), diff(n)
write(*,*) "test_cross_product_int32"
u = [1,0,0]
v = [0,1,0]
expected = [0,0,1]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) == 0), &
"cross_product(u,v) == expected failed.")
end subroutine test_cross_product_int32

subroutine test_cross_product_int64(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

integer, parameter :: n = 3
integer(int64) :: u(n), v(n), expected(n), diff(n)
write(*,*) "test_cross_product_int64"
u = [1,0,0]
v = [0,1,0]
expected = [0,0,1]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) == 0), &
"cross_product(u,v) == expected failed.")
end subroutine test_cross_product_int64

subroutine test_cross_product_rsp(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

integer, parameter :: n = 3
real(sp) :: u(n), v(n), expected(n), diff(n)
write(*,*) "test_cross_product_rsp"
u = [1.1_sp,2.5_sp,2.4_sp]
v = [0.5_sp,1.5_sp,2.5_sp]
expected = [2.65_sp,-1.55_sp,0.4_sp]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) < sptol), &
"all(abs(cross_product(u,v)-expected)) < sptol failed.")
end subroutine test_cross_product_rsp

subroutine test_cross_product_rdp(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

integer, parameter :: n = 3
real(dp) :: u(n), v(n), expected(n), diff(n)
write(*,*) "test_cross_product_rdp"
u = [1.1_dp,2.5_dp,2.4_dp]
v = [0.5_dp,1.5_dp,2.5_dp]
expected = [2.65_dp,-1.55_dp,0.4_dp]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) < dptol), &
"all(abs(cross_product(u,v)-expected)) < dptol failed.")
end subroutine test_cross_product_rdp

subroutine test_cross_product_rqp(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

#:if WITH_QP
integer, parameter :: n = 3
real(qp) :: u(n), v(n), expected(n), diff(n)
write(*,*) "test_cross_product_rqp"
u = [1.1_qp,2.5_qp,2.4_qp]
v = [0.5_qp,1.5_qp,2.5_qp]
expected = [2.65_qp,-1.55_qp,0.4_qp]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) < qptol), &
"all(abs(cross_product(u,v)-expected)) < qptol failed.")
#:else
call skip_test(error, "Quadruple precision is not enabled")
#:endif
end subroutine test_cross_product_rqp

subroutine test_cross_product_csp(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

integer, parameter :: n = 3
complex(sp) :: u(n), v(n), expected(n), diff(n)
write(*,*) "test_cross_product_csp"
u = [cmplx(0,1,sp),cmplx(1,0,sp),cmplx(0,0,sp)]
v = [cmplx(1,1,sp),cmplx(0,0,sp),cmplx(1,0,sp)]
expected = [cmplx(1,0,sp),cmplx(0,-1,sp),cmplx(-1,-1,sp)]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) < sptol), &
"all(abs(cross_product(u,v)-expected)) < sptol failed.")
end subroutine test_cross_product_csp

subroutine test_cross_product_cdp(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

integer, parameter :: n = 3
complex(dp) :: u(n), v(n), expected(n), diff(n)
write(*,*) "test_cross_product_cdp"
u = [cmplx(0,1,dp),cmplx(1,0,dp),cmplx(0,0,dp)]
v = [cmplx(1,1,dp),cmplx(0,0,dp),cmplx(1,0,dp)]
expected = [cmplx(1,0,dp),cmplx(0,-1,dp),cmplx(-1,-1,dp)]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) < dptol), &
"all(abs(cross_product(u,v)-expected)) < dptol failed.")
end subroutine test_cross_product_cdp

subroutine test_cross_product_cqp(error)
!> Error handling
type(error_type), allocatable, intent(out) :: error

#:if WITH_QP
integer, parameter :: n = 3
complex(qp) :: u(n), v(n), expected(n), diff(n)
write(*,*) "test_cross_product_cqp"
u = [cmplx(0,1,qp),cmplx(1,0,qp),cmplx(0,0,qp)]
v = [cmplx(1,1,qp),cmplx(0,0,qp),cmplx(1,0,qp)]
expected = [cmplx(1,0,qp),cmplx(0,-1,qp),cmplx(-1,-1,qp)]
diff = expected - cross_product(u,v)
call check(error, all(abs(diff) < qptol), &
"all(abs(cross_product(u,v)-expected)) < qptol failed.")
#:else
call skip_test(error, "Quadruple precision is not enabled")
#:endif
end subroutine test_cross_product_cqp

pure recursive function catalan_number(n) result(value)
integer, intent(in) :: n
Expand Down