The number of streams can be influenced through the max_concurrent_streams
configuration variable.
There are special values:
- > 0: Use at most that many streams (probably buggy).
- = 0: Determine the number of streams automatically (definitely buggy).
- = -1: Use the default stream.
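For reference, the value can be changed per scope with dace.config.set_temporary(), the same API the reproducer below uses; a minimal sketch:

import dace

# Compile everything in this scope with exactly one user-allocated stream.
with dace.config.set_temporary("compiler.cuda.max_concurrent_streams", value=1):
    ...  # build and compile SDFGs here

# Compile everything in this scope against the default stream.
with dace.config.set_temporary("compiler.cuda.max_concurrent_streams", value=-1):
    ...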
When using a single stream, i.e. max_concurrent_streams := 1, there is a synchronization at the end: before control returns to Python, a {cuda, hip}StreamSynchronize() call is emitted at the end of the function.
Thus, CompiledSDFG.__call__() only returns after everything has been computed.
When using the default stream, however, there is no such call: CompiledSDFG.__call__() returns immediately after all kernels have been launched.
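A caller-side workaround (an assumption on my part, not something DaCe guarantees) would be to synchronize the device manually after the call, for example with CuPy:

import cupy as cp

# With max_concurrent_streams = -1 the call returns right after the kernel
# launches, so wait for all outstanding GPU work explicitly
# (hypothetical workaround; `csdfg` and `arguments` stand for the compiled
# SDFG and its call arguments):
csdfg(**arguments)
cp.cuda.runtime.deviceSynchronize()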
Furthermore, looking at generate_state() indicates that the synchronization logic ignores everything that runs on the default stream.
Thus, if an SDFG mixes "user allocated streams" and the default stream, synchronization will most likely fail as well.
Here is a reproducer:
import uuid

import dace


def make_sdfg(
    sdfg_name: str,
    with_tasklet: bool,
) -> dace.SDFG:
    sdfg = dace.SDFG(sdfg_name + f"_{str(uuid.uuid1()).replace('-', '_')}")
    R = sdfg.add_symbol("R", dace.int32)
    X_size = sdfg.add_symbol("X_size", dace.int32)
    sdfg.add_scalar("T_GPU", dace.int32, storage=dace.StorageType.GPU_Global)
    sdfg.add_scalar("T", dace.int32, transient=True)
    sdfg.add_array("X", [X_size], dace.int32)

    first_state = sdfg.add_state()
    t_gpu_1 = first_state.add_access("T_GPU")
    t_cpu_1 = first_state.add_access("T")
    first_state.add_mapped_tasklet(
        "write",
        map_ranges={"i": "0"},
        inputs={},
        code="val = 10",
        outputs={"val": dace.Memlet("T_GPU[i]")},
        output_nodes={t_gpu_1},
        external_edges=True,
        schedule=dace.ScheduleType.GPU_Device,
    )
    first_state.add_nedge(t_gpu_1, t_cpu_1, dace.Memlet("T_GPU[0] -> [0]"))

    if with_tasklet:
        sdfg.add_array("U", shape=(1,), dtype=dace.int32, transient=False)
        u_cpu_1 = first_state.add_access("U")
        tlet1 = first_state.add_tasklet(
            "cpu_computation", inputs={"__in"}, code="__out = __in + 1", outputs={"__out"}
        )
        first_state.add_edge(t_cpu_1, None, tlet1, "__in", dace.Memlet("T[0]"))
        first_state.add_edge(tlet1, "__out", u_cpu_1, None, dace.Memlet("U[0]"))

    # The second map does not need to be on GPU; it just has to use the value
    # that is computed by the first GPU map.
    second_state = sdfg.add_state_after(first_state)
    second_state.add_mapped_tasklet(
        "compute",
        map_ranges=dict(i=f"0:{R}"),
        code="val = 1.0",
        inputs={},
        outputs={"val": dace.Memlet(data="X", subset="i")},
        external_edges=True,
    )
    sdfg.out_edges(first_state)[0].data.assignments["R"] = "T"
    sdfg.out_edges(first_state)[0].data.assignments["S"] = "False"
    return sdfg


with dace.config.set_temporary("compiler.cuda.max_concurrent_streams", value=1):
    sdfg1 = make_sdfg("sdfg_with_one_stream", True)
    csdfg1 = sdfg1.compile()
    print(f"csdfg1: {csdfg1.filename}")
    # sdfg1.view()

with dace.config.set_temporary("compiler.cuda.max_concurrent_streams", value=-1):
    sdfg2 = make_sdfg("sdfg_with_using_the_default_stream", True)
    csdfg2 = sdfg2.compile()
    print(f"csdfg2: {csdfg2.filename}")

with dace.config.set_temporary("compiler.cuda.max_concurrent_streams", value=1):
    sdfg3 = make_sdfg("sdfg_with_one_stream_but_without_tlet", False)
    csdfg3 = sdfg3.compile()
    print(f"csdfg3: {csdfg3.filename}")
    # sdfg3.view()
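To make the difference visible without reading the generated .cu files by hand, one can grep the generated code; a minimal sketch, assuming SDFG.generate_code() returns code objects carrying a .code string:

def has_stream_sync(sdfg: dace.SDFG) -> bool:
    # True if any generated file contains a stream synchronization call.
    return any("StreamSynchronize" in obj.code for obj in sdfg.generate_code())

print(has_stream_sync(sdfg1))  # expected: True  (single stream, sync emitted)
print(has_stream_sync(sdfg2))  # expected: False (default stream, no sync)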
This will generate the following SDFG:
[SDFG rendering of the reproducer; image not included]
Here the GPU Map in the first state computes something that is then used on the interstate edge or inside a Tasklet.
Thus a synchronization is needed before that data is used.
If we are using a single stream, this synchronization is indeed present:
...
{
    __dace_runkernel_write_map_0_0_2(__state, T_GPU);
    DACE_GPU_CHECK(cudaMemcpyAsync(&T, T_GPU, 1 * sizeof(int), cudaMemcpyDeviceToHost, __state->gpu_context->streams[0]));
    DACE_GPU_CHECK(cudaStreamSynchronize(__state->gpu_context->streams[0]));
    {
        int __in = T;
        int __out;
        ///////////////////
        // Tasklet code (cpu_computation)
        __out = (__in + 1);
        ///////////////////
        U[0] = __out;
    }
}
R = T;
S = false;
...
But when using the default stream, i.e. max_concurrent_streams set to -1, we get:
...
{
    __dace_runkernel_write_map_0_0_2(__state, T_GPU);
    DACE_GPU_CHECK(cudaMemcpyAsync(&T, T_GPU, 1 * sizeof(int), cudaMemcpyDeviceToHost, nullptr));
    {
        int __in = T;
        int __out;
        ///////////////////
        // Tasklet code (cpu_computation)
        __out = (__in + 1);
        ///////////////////
        U[0] = __out;
    }
}
R = T;
S = false;
...
without the synchronization. The host Tasklet may therefore read T before the asynchronous device-to-host copy has completed.