Skip to content

Commit 22a4265

Browse files
committed
added tests for alltoall / alltoallv
1 parent 27cecf4 commit 22a4265

File tree

2 files changed

+130
-3
lines changed

2 files changed

+130
-3
lines changed

collcomm.scm

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,7 +1208,6 @@ C_word MPI_scatterv_f64vector (C_word sendbuf, C_word sendlengths,
12081208
(let ((myself (MPI:comm-rank comm))
12091209
(nprocs (MPI:comm-size comm))
12101210
(tysize (MPI:type-size ty)))
1211-
(print "scatter: size v = " (blob-size v))
12121211
(if (= root myself)
12131212
;; If this is the root process, scatter the data
12141213
(if (<= (* nprocs sendcount tysize) (blob-size v))
@@ -2772,8 +2771,8 @@ C_word MPI_alltoall_u8vector (C_word data, C_word sendcount, C_word recv, C_word
27722771

27732772
vect = C_c_u8vector(data);
27742773
slen = (int)C_num_to_int (sendcount);
2775-
2776-
MPI_Altoall(vect, slen, MPI_UNSIGNED_CHAR, vrecv, rlen, MPI_UNSIGNED_CHAR, Comm_val(comm));
2774+
2775+
MPI_Alltoall(vect, slen, MPI_UNSIGNED_CHAR, vrecv, rlen, MPI_UNSIGNED_CHAR, Comm_val(comm));
27772776

27782777
C_return (recv);
27792778
}

tests/mpitest.scm

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,28 @@
181181
(define vvintdata (list-tabulate size (lambda (i) (list-tabulate (+ i 1) (lambda (j) (+ (* 10 i) j))))))
182182
(define vvflodata (list-tabulate size (lambda (i) (list-tabulate (+ i 1) (lambda (j) (+ i (* 0.1 j)))))))
183183

184+
(define allvsdata (list-tabulate size
185+
(lambda (i)
186+
(list-tabulate vsize
187+
(lambda (j) (integer->char (+ myrank i 97)))))))
188+
(define allvintdata (list-tabulate size (lambda (i) (list-tabulate vsize (lambda (j) (+ (* 10 (+ myrank i)) j))))))
189+
(define allvflodata (list-tabulate size (lambda (i) (list-tabulate vsize (lambda (j) (+ myrank i (* 0.1 j)))))))
190+
191+
(define allvvsdata (list-tabulate size
192+
(lambda (i)
193+
(list-tabulate (+ i 1)
194+
(lambda (j) (integer->char (+ myrank i 97)))))))
195+
(define allvvintdata (list-tabulate size (lambda (i) (list-tabulate (+ i 1) (lambda (j) (+ (* 10 (+ myrank i)) j))))))
196+
(define allvvflodata (list-tabulate size (lambda (i) (list-tabulate (+ i 1) (lambda (j) (+ myrank i (* 0.1 j)))))))
197+
198+
(define rallvvsdata (list-tabulate size
199+
(lambda (i)
200+
(list-tabulate (+ myrank 1)
201+
(lambda (j) (integer->char (+ myrank i 97)))))))
202+
(define rallvvintdata (list-tabulate size (lambda (i) (list-tabulate (+ myrank 1) (lambda (j) (+ (* 10 (+ myrank i)) j))))))
203+
(define rallvvflodata (list-tabulate size (lambda (i) (list-tabulate (+ myrank 1) (lambda (j) (+ myrank i (* 0.1 j)))))))
204+
205+
184206
(test-group "MPI test 1"
185207

186208
(if (zero? myrank)
@@ -500,6 +522,112 @@
500522
)
501523

502524

525+
(test-group "MPI test alltoall / alltoallv"
526+
527+
;; All to all
528+
529+
(let* ((test-alltoall
530+
(lambda (alltoall data total)
531+
(print myrank ": alltoall " data)
532+
(let ((res (alltoall data vsize comm-world)))
533+
(print myrank ": received (alltoall) "
534+
(map (lambda (x) (if (blob? x) (blob->string x) x)) res))
535+
(test res total))
536+
(MPI:barrier comm-world))))
537+
538+
(test-alltoall MPI:alltoall-bytevector
539+
(string->blob
540+
(list->string
541+
(concatenate allvsdata)))
542+
(map (lambda (lst)
543+
(string->blob
544+
(list->string lst)))
545+
allvsdata))
546+
547+
(test-alltoall MPI:alltoall-s8vector
548+
(list->s8vector (concatenate allvintdata))
549+
(map list->s8vector allvintdata))
550+
(test-alltoall MPI:alltoall-u8vector
551+
(list->u8vector (concatenate allvintdata))
552+
(map list->u8vector allvintdata))
553+
(test-alltoall MPI:alltoall-s16vector
554+
(list->s16vector (concatenate allvintdata))
555+
(map list->s16vector allvintdata))
556+
(test-alltoall MPI:alltoall-u16vector
557+
(list->u16vector (concatenate allvintdata))
558+
(map list->u16vector allvintdata))
559+
(test-alltoall MPI:alltoall-s32vector
560+
(list->s32vector (concatenate allvintdata))
561+
(map list->s32vector allvintdata))
562+
(test-alltoall MPI:alltoall-u32vector
563+
(list->u32vector (concatenate allvintdata))
564+
(map list->u32vector allvintdata))
565+
(test-alltoall MPI:alltoall-f32vector
566+
(list->f32vector (concatenate allvflodata))
567+
(map list->f32vector allvflodata))
568+
(test-alltoall MPI:alltoall-f64vector
569+
(list->f64vector (concatenate allvflodata))
570+
(map list->f64vector allvflodata))
571+
)
572+
573+
(let* ((test-alltoallv
574+
(lambda (alltoallv data sendlens total)
575+
(print myrank ": alltoallv " data " " sendlens)
576+
(let ((res (alltoallv data sendlens comm-world)))
577+
(print myrank ": received (alltoallv) "
578+
(map (lambda (x) (if (blob? x) (blob->string x) x)) res))
579+
(test res total))
580+
(MPI:barrier comm-world))))
581+
582+
(test-alltoallv MPI:alltoallv-bytevector
583+
(string->blob
584+
(list->string
585+
(concatenate allvvsdata)))
586+
(list->s32vector (map length allvvsdata))
587+
(map (lambda (lst)
588+
(string->blob
589+
(list->string lst)))
590+
rallvvsdata))
591+
592+
(test-alltoallv MPI:alltoallv-s8vector
593+
(list->s8vector (concatenate allvvintdata))
594+
(list->s32vector (map length allvvintdata))
595+
(map list->s8vector rallvvintdata))
596+
(test-alltoallv MPI:alltoallv-u8vector
597+
(list->u8vector (concatenate allvvintdata))
598+
(list->s32vector (map length allvvintdata))
599+
(map list->u8vector rallvvintdata))
600+
(test-alltoallv MPI:alltoallv-s16vector
601+
(list->s16vector (concatenate allvvintdata))
602+
(list->s32vector (map length allvvintdata))
603+
(map list->s16vector rallvvintdata))
604+
(test-alltoallv MPI:alltoallv-u16vector
605+
(list->u16vector (concatenate allvvintdata))
606+
(list->s32vector (map length allvvintdata))
607+
(map list->u16vector rallvvintdata))
608+
(test-alltoallv MPI:alltoallv-s32vector
609+
(list->s32vector (concatenate allvvintdata))
610+
(list->s32vector (map length allvvintdata))
611+
(map list->s32vector rallvvintdata))
612+
(test-alltoallv MPI:alltoallv-u32vector
613+
(list->u32vector (concatenate allvvintdata))
614+
(list->s32vector (map length allvvintdata))
615+
(map list->u32vector rallvvintdata))
616+
(test-alltoallv MPI:alltoallv-f32vector
617+
(list->f32vector (concatenate allvvflodata))
618+
(list->s32vector (map length allvvflodata))
619+
(map list->f32vector rallvvflodata))
620+
(test-alltoallv MPI:alltoallv-f64vector
621+
(list->f64vector (concatenate allvvflodata))
622+
(list->s32vector (map length allvvflodata))
623+
(map list->f64vector rallvvflodata))
624+
625+
)
626+
627+
)
628+
629+
630+
503631
(test-group "MPI test reduce/reduce all"
504632

505633
;; Reduce

0 commit comments

Comments
 (0)