Skip to content

Commit 1632355

Browse files
committed
corr_dev:completed impl
1 parent 1fe73b1 commit 1632355

File tree

3 files changed

+208
-160
lines changed

3 files changed

+208
-160
lines changed

src/stdlib_experimental_stats.fypp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,30 @@ module stdlib_experimental_stats
7777
, merge(size(x, 1), size(x, 2), mask = 1<dim))
7878
end function ${RName}$
7979
#:endfor
80+
81+
82+
#:for k1, t1 in RC_KINDS_TYPES
83+
#:set RName = rname("corr_mask",2, t1, k1)
84+
module function ${RName}$(x, dim, mask) result(res)
85+
${t1}$, intent(in) :: x(:, :)
86+
integer, intent(in) :: dim
87+
logical, intent(in) :: mask(:,:)
88+
${t1}$ :: res(merge(size(x, 1), size(x, 2), mask = 1<dim)&
89+
, merge(size(x, 1), size(x, 2), mask = 1<dim))
90+
end function ${RName}$
91+
#:endfor
92+
93+
#:for k1, t1 in INT_KINDS_TYPES
94+
#:set RName = rname("corr_mask",2, t1, k1, 'dp')
95+
module function ${RName}$(x, dim, mask) result(res)
96+
${t1}$, intent(in) :: x(:, :)
97+
integer, intent(in) :: dim
98+
logical, intent(in) :: mask(:,:)
99+
real(dp) :: res(merge(size(x, 1), size(x, 2), mask = 1<dim)&
100+
, merge(size(x, 1), size(x, 2), mask = 1<dim))
101+
end function ${RName}$
102+
#:endfor
103+
80104
end interface corr
81105

82106

src/stdlib_experimental_stats_corr.fypp

Lines changed: 145 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ contains
1818
logical, intent(in), optional :: mask
1919
real(${k1}$) :: res
2020

21-
if (.not.optval(mask, .true.)) then
21+
if (.not.optval(mask, .true.) .or. size(x) < 2) then
2222
res = ieee_value(1._${k1}$, ieee_quiet_nan)
2323
return
2424
end if
@@ -37,7 +37,7 @@ contains
3737
logical, intent(in), optional :: mask
3838
real(dp) :: res
3939

40-
if (.not.optval(mask, .true.)) then
40+
if (.not.optval(mask, .true.) .or. size(x) < 2) then
4141
res = ieee_value(1._dp, ieee_quiet_nan)
4242
return
4343
end if
@@ -56,7 +56,7 @@ contains
5656
logical, intent(in) :: mask(:)
5757
real(${k1}$) :: res
5858

59-
if (all(.not.mask)) then
59+
if (count(mask) < 2) then
6060
res = ieee_value(1._${k1}$, ieee_quiet_nan)
6161
return
6262
end if
@@ -75,7 +75,7 @@ contains
7575
logical, intent(in) :: mask(:)
7676
real(dp) :: res
7777

78-
if (all(.not.mask)) then
78+
if (count(mask) < 2) then
7979
res = ieee_value(1._dp, ieee_quiet_nan)
8080
return
8181
end if
@@ -99,7 +99,7 @@ contains
9999
${t1}$ :: mean_(merge(size(x, 1), size(x, 2), mask = 1<dim))
100100
${t1}$ :: center(size(x, 1),size(x, 2))
101101

102-
if (.not.optval(mask, .true.)) then
102+
if (.not.optval(mask, .true.) .or. size(x) < 2) then
103103
res = ieee_value(1._${k1}$, ieee_quiet_nan)
104104
return
105105
end if
@@ -127,7 +127,7 @@ contains
127127
case default
128128
call error_stop("ERROR (corr): wrong dimension")
129129
end select
130-
130+
131131
mean_ = 1 / sqrt(diag(res))
132132
do i = 1, size(res, 1)
133133
do j = 1, size(res, 2)
@@ -152,7 +152,7 @@ contains
152152
real(dp) :: mean_(merge(size(x, 1), size(x, 2), mask = 1<dim))
153153
real(dp) :: center(size(x, 1),size(x, 2))
154154

155-
if (.not.optval(mask, .true.)) then
155+
if (.not.optval(mask, .true.) .or. size(x) < 2) then
156156
res = ieee_value(1._dp, ieee_quiet_nan)
157157
return
158158
end if
@@ -184,124 +184,144 @@ contains
184184
#:endfor
185185

186186

187-
! #:for k1, t1 in RC_KINDS_TYPES
188-
! #:set RName = rname("corr_mask",2, t1, k1)
189-
! module function ${RName}$(x, dim, mask, corrected) result(res)
190-
! ${t1}$, intent(in) :: x(:, :)
191-
! integer, intent(in) :: dim
192-
! logical, intent(in) :: mask(:,:)
193-
! logical, intent(in), optional :: corrected
194-
! ${t1}$ :: res(merge(size(x, 1), size(x, 2), mask = 1<dim)&
195-
! , merge(size(x, 1), size(x, 2), mask = 1<dim))
196-
!
197-
! integer :: i, j, n
198-
! ${t1}$ :: mean_(merge(size(x, 1), size(x, 2), mask = 1<dim))
199-
! ${t1}$ :: center(size(x, 1),size(x, 2))
200-
!
201-
! mean_ = mean(x, dim, mask = mask)
202-
! select case(dim)
203-
! case(1)
204-
! do i = 1, size(x, 1)
205-
! center(i, :) = merge( x(i, :) - mean_,&
206-
! #:if t1[0] == 'r'
207-
! 0._${k1}$,&
208-
! #:else
209-
! cmplx(0,0,kind=${k1}$),&
210-
! #:endif
211-
! mask(i, :))
212-
! end do
213-
! #:if t1[0] == 'r'
214-
! res = matmul( transpose(center), center)
215-
! #:else
216-
! res = matmul( transpose(conjg(center)), center)
217-
! #:endif
218-
! do j = 1, size(res, 2)
219-
! do i = 1, size(res, 1)
220-
! n = count(merge(.true., .false., mask(:, i) .and. mask(:, j)))
221-
! res(i, j) = res(i, j) / (n - merge(1, 0,&
222-
! optval(corrected, .true.) .and. n > 0))
223-
! end do
224-
! end do
225-
! case(2)
226-
! do i = 1, size(x, 2)
227-
! center(:, i) = merge( x(:, i) - mean_,&
228-
! #:if t1[0] == 'r'
229-
! 0._${k1}$,&
230-
! #:else
231-
! cmplx(0,0,kind=${k1}$),&
232-
! #:endif
233-
! mask(:, i))
234-
! end do
235-
! #:if t1[0] == 'r'
236-
! res = matmul( center, transpose(center))
237-
! #:else
238-
! res = matmul( center, transpose(conjg(center)))
239-
! #:endif
240-
! do j = 1, size(res, 2)
241-
! do i = 1, size(res, 1)
242-
! n = count(merge(.true., .false., mask(i, :) .and. mask(j, :)))
243-
! res(i, j) = res(i, j) / (n - merge(1, 0,&
244-
! optval(corrected, .true.) .and. n > 0))
245-
! end do
246-
! end do
247-
! case default
248-
! call error_stop("ERROR (corr): wrong dimension")
249-
! end select
250-
!
251-
! end function ${RName}$
252-
! #:endfor
253-
!
254-
!
255-
! #:for k1, t1 in INT_KINDS_TYPES
256-
! #:set RName = rname("corr_mask",2, t1, k1, 'dp')
257-
! module function ${RName}$(x, dim, mask, corrected) result(res)
258-
! ${t1}$, intent(in) :: x(:, :)
259-
! integer, intent(in) :: dim
260-
! logical, intent(in) :: mask(:,:)
261-
! logical, intent(in), optional :: corrected
262-
! real(dp) :: res(merge(size(x, 1), size(x, 2), mask = 1<dim)&
263-
! , merge(size(x, 1), size(x, 2), mask = 1<dim))
264-
!
265-
! integer :: i, j, n
266-
! real(dp) :: mean_(merge(size(x, 1), size(x, 2), mask = 1<dim))
267-
! real(dp) :: center(size(x, 1),size(x, 2))
268-
!
269-
! mean_ = mean(x, dim, mask = mask)
270-
! select case(dim)
271-
! case(1)
272-
! do i = 1, size(x, 1)
273-
! center(i, :) = merge( x(i, :) - mean_,&
274-
! 0._dp,&
275-
! mask(i, :))
276-
! end do
277-
! res = matmul( transpose(center), center)
278-
! do j = 1, size(res, 2)
279-
! do i = 1, size(res, 1)
280-
! n = count(merge(.true., .false., mask(:, i) .and. mask(:, j)))
281-
! res(i, j) = res(i, j) / (n - merge(1, 0,&
282-
! optval(corrected, .true.) .and. n > 0))
283-
! end do
284-
! end do
285-
! case(2)
286-
! do i = 1, size(x, 2)
287-
! center(:, i) = merge( x(:, i) - mean_,&
288-
! 0._dp,&
289-
! mask(:, i))
290-
! end do
291-
! res = matmul( center, transpose(center))
292-
! do j = 1, size(res, 2)
293-
! do i = 1, size(res, 1)
294-
! n = count(merge(.true., .false., mask(i, :) .and. mask(j, :)))
295-
! res(i, j) = res(i, j) / (n - merge(1, 0,&
296-
! optval(corrected, .true.) .and. n > 0))
297-
! end do
298-
! end do
299-
! case default
300-
! call error_stop("ERROR (corr): wrong dimension")
301-
! end select
302-
!
303-
! end function ${RName}$
304-
! #:endfor
187+
#:for k1, t1 in RC_KINDS_TYPES
188+
#:set RName = rname("corr_mask",2, t1, k1)
189+
module function ${RName}$(x, dim, mask) result(res)
190+
${t1}$, intent(in) :: x(:, :)
191+
integer, intent(in) :: dim
192+
logical, intent(in) :: mask(:,:)
193+
${t1}$ :: res(merge(size(x, 1), size(x, 2), mask = 1<dim)&
194+
, merge(size(x, 1), size(x, 2), mask = 1<dim))
195+
196+
integer :: i, j
197+
${t1}$ :: centeri_(merge(size(x, 2), size(x, 1), mask = 1<dim))
198+
${t1}$ :: centerj_(merge(size(x, 2), size(x, 1), mask = 1<dim))
199+
logical :: mask_(merge(size(x, 2), size(x, 1), mask = 1<dim))
200+
201+
select case(dim)
202+
case(1)
203+
do i = 1, size(x, 2)
204+
do j = 1, size(x, 2)
205+
mask_ = merge(.true., .false., mask(:, i) .and. mask(:, j))
206+
centeri_ = merge( x(:, i) - mean(x(:, i), mask = mask_),&
207+
#:if t1[0] == 'r'
208+
0._${k1}$,&
209+
#:else
210+
cmplx(0,0,kind=${k1}$),&
211+
#:endif
212+
mask_)
213+
centerj_ = merge( x(:, j) - mean(x(:, j), mask = mask_),&
214+
#:if t1[0] == 'r'
215+
0._${k1}$,&
216+
#:else
217+
cmplx(0,0,kind=${k1}$),&
218+
#:endif
219+
mask_)
220+
221+
#:if t1[0] == 'r'
222+
res(j, i) = dot_product( centerj_, centeri_)&
223+
/sqrt(dot_product( centeri_, centeri_)*&
224+
dot_product( centerj_, centerj_))
225+
#:else
226+
res(j, i) = dot_product( (conjg(centerj_)), centeri_)&
227+
/sqrt(dot_product( (conjg(centeri_)), centeri_)*&
228+
dot_product( (conjg(centerj_)), centerj_))
229+
#:endif
230+
231+
end do
232+
end do
233+
case(2)
234+
do i = 1, size(x, 1)
235+
do j = 1, size(x, 1)
236+
mask_ = merge(.true., .false., mask(i, :) .and. mask(j, :))
237+
centeri_ = merge( x(i, :) - mean(x(i, :), mask = mask_),&
238+
#:if t1[0] == 'r'
239+
0._${k1}$,&
240+
#:else
241+
cmplx(0,0,kind=${k1}$),&
242+
#:endif
243+
mask_)
244+
centerj_ = merge( x(j, :) - mean(x(j, :), mask = mask_),&
245+
#:if t1[0] == 'r'
246+
0._${k1}$,&
247+
#:else
248+
cmplx(0,0,kind=${k1}$),&
249+
#:endif
250+
mask_)
251+
252+
#:if t1[0] == 'r'
253+
res(j, i) = dot_product( centerj_, centeri_)&
254+
/sqrt(dot_product( centeri_, centeri_)*&
255+
dot_product( centerj_, centerj_))
256+
#:else
257+
res(j, i) = dot_product( (conjg(centerj_)), centeri_)&
258+
/sqrt(dot_product( (conjg(centeri_)), centeri_)*&
259+
dot_product( (conjg(centerj_)), centerj_))
260+
#:endif
261+
end do
262+
end do
263+
case default
264+
call error_stop("ERROR (corr): wrong dimension")
265+
end select
266+
267+
end function ${RName}$
268+
#:endfor
269+
270+
271+
#:for k1, t1 in INT_KINDS_TYPES
272+
#:set RName = rname("corr_mask",2, t1, k1, 'dp')
273+
module function ${RName}$(x, dim, mask) result(res)
274+
${t1}$, intent(in) :: x(:, :)
275+
integer, intent(in) :: dim
276+
logical, intent(in) :: mask(:,:)
277+
real(dp) :: res(merge(size(x, 1), size(x, 2), mask = 1<dim)&
278+
, merge(size(x, 1), size(x, 2), mask = 1<dim))
279+
280+
integer :: i, j
281+
real(dp) :: centeri_(merge(size(x, 2), size(x, 1), mask = 1<dim))
282+
real(dp) :: centerj_(merge(size(x, 2), size(x, 1), mask = 1<dim))
283+
logical :: mask_(merge(size(x, 2), size(x, 1), mask = 1<dim))
284+
285+
select case(dim)
286+
case(1)
287+
do i = 1, size(x, 2)
288+
do j = 1, size(x, 2)
289+
mask_ = merge(.true., .false., mask(:, i) .and. mask(:, j))
290+
centeri_ = merge( x(:, i) - mean(x(:, i), mask = mask_),&
291+
0._dp,&
292+
mask_)
293+
centerj_ = merge( x(:, j) - mean(x(:, j), mask = mask_),&
294+
0._dp,&
295+
mask_)
296+
297+
res(j, i) = dot_product( centerj_, centeri_)&
298+
/sqrt(dot_product( centeri_, centeri_)*&
299+
dot_product( centerj_, centerj_))
300+
301+
end do
302+
end do
303+
case(2)
304+
do i = 1, size(x, 1)
305+
do j = 1, size(x, 1)
306+
mask_ = merge(.true., .false., mask(i, :) .and. mask(j, :))
307+
centeri_ = merge( x(i, :) - mean(x(i, :), mask = mask_),&
308+
0._dp,&
309+
mask_)
310+
centerj_ = merge( x(j, :) - mean(x(j, :), mask = mask_),&
311+
0._dp,&
312+
mask_)
313+
314+
res(j, i) = dot_product( centerj_, centeri_)&
315+
/sqrt(dot_product( centeri_, centeri_)*&
316+
dot_product( centerj_, centerj_))
317+
end do
318+
end do
319+
case default
320+
call error_stop("ERROR (corr): wrong dimension")
321+
end select
322+
323+
end function ${RName}$
324+
#:endfor
305325

306326

307327
end submodule

0 commit comments

Comments
 (0)