Skip to content

Commit 85739ee

Browse files
committed
Add datatype for MPI Bcast.
1 parent 815e17d commit 85739ee

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

sigpy/backend.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,17 +265,21 @@ def reduce(self, input, root=0):
265265
else:
266266
self.mpi_comm.Reduce(cpu_input, None, root=root)
267267

268-
def bcast(self, input, root=0):
268+
def bcast(self, input, root=0, datatype=None):
269269
"""Broadcast from root to other nodes.
270270
271271
Args:
272272
input (array): input array.
273273
root (int): root node rank.
274+
datatype (int): MPI datatype for broadcasting.
274275
275276
"""
277+
if config.mpi4py_enabled:
278+
datatype = MPI.COMPLEX
279+
276280
if self.size > 1:
277281
cpu_input = to_device(input, cpu_device)
278-
self.mpi_comm.Bcast(cpu_input, root=root)
282+
self.mpi_comm.Bcast((cpu_input, datatype), root=root)
279283
copyto(input, cpu_input)
280284

281285
def gatherv(self, input, root=0):

0 commit comments

Comments
 (0)