Skip to content

Commit 74e06f3

Browse files
Merge branch 'broadcasting'
2 parents 5468e19 + 0355ca4 commit 74e06f3

20 files changed

+3242
-174
lines changed

src/OrdinaryDiffEq.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ module OrdinaryDiffEq
3030
u_modified!,savevalues!,add_tstop!,add_saveat!,set_reltol!,
3131
set_abstol!
3232

33+
macro tight_loop_macros(ex)
34+
:($(esc(ex)))
35+
end
36+
3337
include("misc_utils.jl")
3438
include("algorithms.jl")
3539
include("alg_utils.jl")

src/dense/generic_dense.jl

Lines changed: 81 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,12 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
353353
"""
354354
@inline function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
355355
if typeof(idxs) <: Void
356-
out = @. (1-Θ)*y₀+Θ*y₁+Θ*-1)*((1-2Θ)*(y₁-y₀)+-1)*dt*k[1] + Θ*dt*k[2])
356+
#out = @. (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
357+
out = (1-Θ)*y₀+Θ*y₁+Θ*-1)*((1-2Θ)*(y₁-y₀)+-1)*dt*k[1] + Θ*dt*k[2])
357358
else
358-
out = similar(y₀,indices(idxs))
359-
@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
359+
#out = similar(y₀,indices(idxs))
360+
#@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
361+
@views out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
360362
end
361363
out
362364
end
@@ -366,10 +368,12 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
366368
"""
367369
@inline function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{1}}) # Default interpolant is Hermite
368370
if typeof(idxs) <: Void
369-
out = @. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
371+
#out = @. k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
372+
out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
370373
else
371-
out = similar(y₀,indices(idxs))
372-
@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
374+
#out = similar(y₀,indices(idxs))
375+
#@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
376+
@views out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
373377
end
374378
out
375379
end
@@ -379,10 +383,12 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
379383
"""
380384
@inline function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{2}}) # Default interpolant is Hermite
381385
if typeof(idxs) <: Void
382-
out = @. (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
386+
#out = @. (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
387+
out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
383388
else
384-
out = similar(y₀,indices(idxs))
385-
@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
389+
#out = similar(y₀,indices(idxs))
390+
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
391+
@views out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
386392
end
387393
out
388394
end
@@ -392,10 +398,12 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
392398
"""
393399
@inline function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{3}}) # Default interpolant is Hermite
394400
if typeof(idxs) <: Void
395-
out = @. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
401+
#out = @. (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
402+
out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
396403
else
397-
out = similar(y₀,indices(idxs))
398-
@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
404+
#out = similar(y₀,indices(idxs))
405+
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
406+
@views out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
399407
end
400408
out
401409
end
@@ -409,9 +417,15 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
409417
if out == nothing
410418
return (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
411419
elseif idxs == nothing
412-
@. out = (1-Θ)*y₀+Θ*y₁+Θ*-1)*((1-2Θ)*(y₁-y₀)+-1)*dt*k[1] + Θ*dt*k[2])
420+
#@. out = (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
421+
@inbounds for i in eachindex(out)
422+
out[i] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k[1][i] + Θ*dt*k[2][i])
423+
end
413424
else
414-
@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
425+
#@views @. out = (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
426+
@inbounds for (j,i) in enumerate(idxs)
427+
out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k[1][i] + Θ*dt*k[2][i])
428+
end
415429
end
416430
end
417431

@@ -422,9 +436,15 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
422436
if out == nothing
423437
return k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
424438
elseif idxs == nothing
425-
@. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
439+
#@. out = k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
440+
@inbounds for i in eachindex(out)
441+
out[i] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
442+
end
426443
else
427-
@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
444+
#@views @. out = k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt
445+
@inbounds for (j,i) in enumerate(idxs)
446+
out[j] = k[1][i] + Θ*(-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(3*dt*k[1][i] + 3*dt*k[2][i] + 6*y₀[i] - 6*y₁[i]) + 6*y₁[i])/dt
447+
end
428448
end
429449
end
430450

@@ -435,9 +455,15 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
435455
if out == nothing
436456
return (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
437457
elseif idxs == nothing
438-
@. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
458+
#@. out = (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
459+
@inbounds for i in eachindex(out)
460+
out[i] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
461+
end
439462
else
440-
@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
463+
#@views @. out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt)
464+
@inbounds for (j,i) in enumerate(idxs)
465+
out[j] = (-4*dt*k[1][i] - 2*dt*k[2][i] - 6*y₀[i] + Θ*(6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i]) + 6*y₁[i])/(dt*dt)
466+
end
441467
end
442468
end
443469

@@ -448,39 +474,41 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
448474
if out == nothing
449475
return (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
450476
elseif idxs == nothing
451-
@. out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
477+
# @. out = (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
478+
for i in eachindex(out)
479+
out[i] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
480+
end
452481
else
453-
@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
482+
#@views @. out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt)
483+
for (j,i) in enumerate(idxs)
484+
out[j] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt)
485+
end
454486
end
455487
end
456488

457489
######################## Linear Interpolants
458490

459491
@inline function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
460-
if typeof(y₀) <: AbstractArray
461-
if typeof(idxs) <: Void
462-
out = @. (1-Θ)*y₀ + Θ*y₁
463-
else
464-
out = similar(y₀,indices(idxs))
465-
Θm1 = (1-Θ)
466-
@views @. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
467-
end
492+
if typeof(idxs) <: Void
493+
#out = @. (1-Θ)*y₀ + Θ*y₁
494+
out = (1-Θ)*y₀ + Θ*y₁
468495
else
469-
out = @. (1-Θ)*y₀ + Θ*y₁
496+
#out = similar(y₀,indices(idxs))
497+
Θm1 = (1-Θ)
498+
#@views @. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
499+
@views out = Θm1*y₀[idxs] + Θ*y₁[idxs]
470500
end
471501
out
472502
end
473503

474504
@inline function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{1}})
475-
if typeof(y₀) <: AbstractArray
476-
if typeof(idxs) <: Void
477-
out = @. (y₁ - y₀)/dt
478-
else
479-
out = similar(y₀,indices(idxs))
480-
@views @. out = (y₁[idxs] - y₀[idxs])/dt
481-
end
505+
if typeof(idxs) <: Void
506+
#out = @. (y₁ - y₀)/dt
507+
out = (y₁ - y₀)/dt
482508
else
483-
out = @. (y₁ - y₀)/dt
509+
#out = similar(y₀,indices(idxs))
510+
#@views @. out = (y₁[idxs] - y₀[idxs])/dt
511+
@views out = (y₁[idxs] - y₀[idxs])/dt
484512
end
485513
out
486514
end
@@ -493,9 +521,15 @@ Linear Interpolation
493521
if out == nothing
494522
return Θm1*y₀[idxs] + Θ*y₁[idxs]
495523
elseif idxs == nothing
496-
@. out = Θm1*y₀ + Θ*y₁
524+
#@. out = Θm1*y₀ + Θ*y₁
525+
@inbounds for i in eachindex(out)
526+
out[i] = Θm1*y₀[i] + Θ*y₁[i]
527+
end
497528
else
498-
@views @. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
529+
#@views @. out = Θm1*y₀[idxs] + Θ*y₁[idxs]
530+
@inbounds for (j,i) in enumerate(idxs)
531+
out[j] = Θm1*y₀[i] + Θ*y₁[i]
532+
end
499533
end
500534
end
501535

@@ -506,8 +540,14 @@ Linear Interpolation
506540
if out == nothing
507541
return (y₁[idxs] - y₀[idxs])/dt
508542
elseif idxs == nothing
509-
@. out = (y₁ - y₀)/dt
543+
#@. out = (y₁ - y₀)/dt
544+
@inbounds for i in eachindex(out)
545+
out[i] = (y₁[i] - y₀[i])/dt
546+
end
510547
else
511-
@views @. out = (y₁[idxs] - y₀[idxs])/dt
548+
#@views @. out = (y₁[idxs] - y₀[idxs])/dt
549+
@inbounds for (j,i) in enumerate(idxs)
550+
out[j] = (y₁[i] - y₀[i])/dt
551+
end
512552
end
513553
end

0 commit comments

Comments
 (0)