Skip to content

Commit ebc61cd

Browse files
committed
Update eye function.
1 parent ce3a106 commit ebc61cd

File tree

5 files changed

+50
-18
lines changed

5 files changed

+50
-18
lines changed

doc/specs/stdlib_linalg.md

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,21 +101,29 @@ end program demo_diag5
101101

102102
Experimental
103103

104+
### Class
105+
106+
Pure function.
107+
104108
### Description
105109

106-
Construct the identity matrix
110+
Construct the identity matrix.
107111

108112
### Syntax
109113

110-
`I = [[stdlib_linalg(module):eye(function)]](n)`
114+
`I = [[stdlib_linalg(module):eye(function)]](dim1 [, dim2])`
111115

112116
### Arguments
113117

114-
`n`: Shall be a scalar of default type `integer`.
118+
`dim1`: Shall be a scalar of default type `integer`.
119+
This is an `intent(in)` argument.
120+
121+
`dim2`: Shall be a scalar of default type `integer`.
122+
This is an `intent(in)` and `optional` argument.
115123

116124
### Return value
117125

118-
Returns the identity matrix, i.e. a square matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`.
126+
Returns the identity matrix, i.e. a square matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer`.
119127

120128
### Example
121129

@@ -124,7 +132,10 @@ program demo_eye1
124132
use stdlib_linalg, only: eye
125133
implicit none
126134
real :: a(3,3)
135+
real :: b(3,4)
127136
A = eye(3)
137+
A = eye(3,3)
138+
B = eye(3,4)
128139
end program demo_eye1
129140
```
130141

src/stdlib_linalg.fypp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,28 @@ module stdlib_linalg
8282

8383
contains
8484

85-
function eye(n) result(res)
86-
!! version: experimental
87-
!!
88-
!! Constructs the identity matrix
89-
!! ([Specification](../page/specs/stdlib_linalg.html#description_1))
90-
integer, intent(in) :: n
91-
integer(int8) :: res(n, n)
92-
integer :: i
93-
res = 0
94-
do i = 1, n
95-
res(i, i) = 1
96-
end do
97-
end function eye
85+
!> Version: experimental
86+
!>
87+
!> Constructs the identity matrix.
88+
!> ([Specification](../page/specs/stdlib_linalg.html#eye-construct-the-identity-matrix))
89+
pure function eye(dim1, dim2) result(result)
90+
91+
integer, intent(in) :: dim1
92+
integer, intent(in), optional :: dim2
93+
integer, allocatable :: result(:, :)
94+
95+
integer :: dim2_
96+
integer :: i
9897

98+
dim2_ = merge(dim2, dim1, present(dim2))
99+
allocate(result(dim1, dim2_))
100+
101+
result = 0
102+
do i = 1, min(dim1, dim2_)
103+
result(i, i) = 1
104+
end do
105+
106+
end function eye
99107

100108
#:for k1, t1 in RCI_KINDS_TYPES
101109
function trace_${t1[0]}$${k1}$(A) result(res)
@@ -108,4 +116,5 @@ contains
108116
end do
109117
end function trace_${t1[0]}$${k1}$
110118
#:endfor
111-
end module
119+
120+
end module stdlib_linalg

src/tests/Makefile.manual

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ all test clean:
1111
$(MAKE) -f Makefile.manual --directory=stats $@
1212
$(MAKE) -f Makefile.manual --directory=string $@
1313
$(MAKE) -f Makefile.manual --directory=math $@
14+
$(MAKE) -f Makefile.manual --directory=linalg $@

src/tests/linalg/Makefile.manual

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
PROGS_SRC = test_linalg.f90
2+
3+
4+
include ../Makefile.manual.test.mk

src/tests/linalg/test_linalg.f90

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ subroutine test_eye
8181
integer :: i
8282
write(*,*) "test_eye"
8383

84+
call check(all(eye(3,3) == diag([(1,i=1,3)])), &
85+
msg="all(eye(3,3) == diag([(1,i=1,3)])) failed.",warn=warn)
86+
87+
rye = eye(3,4)
88+
call check(sum(rye(:,1:3) - diag([(1.0_sp,i=1,3)])) < sptol, &
89+
msg="sum(rye(:,1:3) - diag([(1.0_sp,i=1,3)])) < sptol", warn=warn)
90+
8491
call check(all(eye(5) == diag([(1,i=1,5)])), &
8592
msg="all(eye(5) == diag([(1,i=1,5)] failed.",warn=warn)
8693

0 commit comments

Comments
 (0)