Skip to content

Commit ce8fb0a

Browse files
JeffBezansonshashi
authored andcommitted
update reduce to 1.0-style keyword args
1 parent 4e0d676 commit ce8fb0a

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ let
8787
b = rand(0:0.001:2, N)
8888
c = rand(N)
8989
t4 = IndexedTable(a, b, c)
90-
@bench "dim-1" reducedim(+, $t4, 1)
91-
@bench "dim-2" reducedim(+, $t4, 2)
90+
@bench "dim-1" reduce(+, $t4, dims=1)
91+
@bench "dim-2" reduce(+, $t4, dims=2)
9292
@bench "vec-dim-1" reducedim_vec(+, $t4, 1)
9393
@bench "vec-dim-2" reducedim_vec(+, $t4, 2)
9494
end

src/reduce.jl

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,24 @@ julia> reduce(@NT(xsum=:x=>+, negtsum=(:t=>-)=>+), t)
101101
See [`Selection`](@ref) for more on what selectors can be specified. Here since each output can select its own input, `select` keyword is unsually unnecessary. If specified, the slections in the reducer tuple will be done over the result of selecting with the `select` argument.
102102
103103
"""
104-
function reduce(f, t::Dataset; select=valuenames(t))
104+
function reduce(f, t::NextTable; select=valuenames(t), kws...)
105+
if haskey(kws, :init)
106+
return _reduce_select_init(f, t, select, kws.data.init)
107+
end
108+
_reduce_select(f, t, select)
109+
end
110+
111+
function _reduce_select(f, t::Dataset, select)
105112
fs, input, T = init_inputs(f, rows(t, select), reduced_type, false)
106113
acc = init_first(fs, input[1])
107114
_reduce(fs, input, acc, 2)
108115
end
109116

110-
function reduce(f, v0, t::Dataset; select=valuenames(t))
117+
function _reduce_select_init(f, t::Dataset, select, v0)
111118
fs, input, T = init_inputs(f, rows(t, select), reduced_type, false)
112119
_reduce(fs, input, v0, 1)
113120
end
114121

115-
@deprecate reduce(f, t::Dataset, v0; select=valuenames(t)) reduce(f, v0, t::Dataset; select=select)
116-
117122
function _reduce(fs, input, acc, start)
118123
@inbounds @simd for i=start:length(input)
119124
acc = _apply(fs, acc, input[i])
@@ -596,16 +601,29 @@ y │
596601
597602
```
598603
"""
599-
function Base.reduce(f, x::NDSparse, dims)
600-
keep = setdiff([1:ndims(x);], map(d->fieldindex(x.index.columns,d), dims))
601-
if isempty(keep)
602-
throw(ArgumentError("to remove all dimensions, use `reduce(f, A)`"))
604+
function Base.reduce(f, x::NDSparse; kws...)
605+
if haskey(kws, :dims)
606+
if haskey(kws, :select) || haskey(kws, :init)
607+
throw(ArgumentError("select and init keyword arguments cannot be used with dims"))
608+
end
609+
dims = kws.data.dims
610+
if dims isa Symbol
611+
dims = [dims]
612+
end
613+
keep = setdiff([1:ndims(x);], map(d->fieldindex(x.index.columns,d), dims))
614+
if isempty(keep)
615+
throw(ArgumentError("to remove all dimensions, use `reduce(f, A)`"))
616+
end
617+
return groupreduce(f, x, (keep...,))
618+
else
619+
select = get(kws, :select, valuenames(x))
620+
if haskey(kws, :init)
621+
return _reduce_select_init(f, x, select, kws.data.init)
622+
end
623+
return _reduce_select(f, x, select)
603624
end
604-
groupreduce(f, x, (keep...,))
605625
end
606626

607-
Base.reduce(f, x::NDSparse, dims::Symbol) = reduce(f, x, [dims])
608-
609627
"""
610628
`reducedim_vec(f::Function, arr::NDSparse, dims)`
611629

test/test_core.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,14 @@ let a = rand(5,5,5)
178178
for dims in ([2,3], [1], [2])
179179
r = dropdims(reduce(+, a; dims=dims), dims=(dims...,))
180180
asnd = convert(NDSparse,a)
181-
b = reduce(+, asnd, dims)
181+
b = reduce(+, asnd, dims=dims)
182182
bv = reducedim_vec(sum, asnd, dims)
183183
c = convert(NDSparse, r)
184184
@test b.index == c.index == bv.index
185185
@test b.data c.data
186186
@test bv.data c.data
187187
end
188-
@test_throws ArgumentError reduce(+, convert(NDSparse,a), [1,2,3])
188+
@test_throws ArgumentError reduce(+, convert(NDSparse,a), dims=[1,2,3])
189189
end
190190

191191
for a in (rand(2,2), rand(3,5))
@@ -727,8 +727,8 @@ end
727727

728728
@testset "reducedim" begin
729729
x = ndsparse((x = [1, 1, 1, 2, 2, 2], y = [1, 2, 2, 1, 2, 2], z = [1, 1, 2, 1, 1, 2]), [1, 2, 3, 4, 5, 6])
730-
@test reduce(+, x, 1) == ndsparse((y = [1, 2, 2], z = [1, 1, 2]), [5, 7, 9])
731-
@test reduce(+, x, (1, 3)) == ndsparse((y = [1, 2],), [5, 16])
730+
@test reduce(+, x, dims=1) == ndsparse((y = [1, 2, 2], z = [1, 1, 2]), [5, 7, 9])
731+
@test reduce(+, x, dims=(1, 3)) == ndsparse((y = [1, 2],), [5, 16])
732732
end
733733

734734
@testset "select" begin
@@ -883,7 +883,7 @@ using OnlineStats
883883

884884
t = table([0.1, 0.5, 0.75], [0, 1, 2], names=[:t, :x])
885885
@test reduce(+, t, select=:t) == 1.35
886-
@test reduce(+, 1.0, t, select = :t) == 2.35
886+
@test reduce(+, t, init = 1.0, select = :t) == 2.35
887887
@test reduce(((a, b)->(t = a.t + b.t, x = a.x + b.x)), t) == (t = 1.35, x = 3)
888888
@test value(reduce(Mean(), t, select=:t)) == 0.45
889889
y = reduce((min, max), t, select=:x)

0 commit comments

Comments
 (0)