Skip to content

Commit 151bf73

Browse files
(dsl): Support Sampler aggregation (#649)
1 parent 047c1ca commit 151bf73

File tree

8 files changed

+338
-7
lines changed

8 files changed

+338
-7
lines changed

modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,16 @@ import zio.elasticsearch.query.sort.SortOrder._
3333
import zio.elasticsearch.query.sort.SourceType.NumberType
3434
import zio.elasticsearch.query.{Distance, FunctionScoreBoostMode, FunctionScoreFunction, InnerHits}
3535
import zio.elasticsearch.request.{CreationOutcome, DeletionOutcome}
36-
import zio.elasticsearch.result.{FilterAggregationResult, Item, MaxAggregationResult, UpdateByQueryResult}
36+
import zio.elasticsearch.result.{
37+
FilterAggregationResult,
38+
Item,
39+
MaxAggregationResult,
40+
SamplerAggregationResult,
41+
SumAggregationResult,
42+
TermsAggregationBucketResult,
43+
TermsAggregationResult,
44+
UpdateByQueryResult
45+
}
3746
import zio.elasticsearch.script.{Painless, Script}
3847
import zio.json.ast.Json.{Arr, Str}
3948
import zio.schema.codec.JsonCodec
@@ -408,6 +417,74 @@ object HttpExecutorSpec extends IntegrationSpec {
408417
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
409418
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
410419
),
420+
test("aggregate using sampler aggregation with sum and terms sub aggregations") {
421+
(
422+
"sampler_agg",
423+
SamplerAggregationResult(
424+
docCount = 4,
425+
subAggregations = Map(
426+
"total_sum_field" -> SumAggregationResult(value = 50.0),
427+
"string_categories" -> TermsAggregationResult(
428+
docErrorCount = 0,
429+
sumOtherDocCount = 0,
430+
buckets = Chunk(
431+
TermsAggregationBucketResult(key = "abc", docCount = 1, subAggregations = Map.empty),
432+
TermsAggregationBucketResult(key = "def", docCount = 1, subAggregations = Map.empty),
433+
TermsAggregationBucketResult(key = "ghi", docCount = 1, subAggregations = Map.empty),
434+
TermsAggregationBucketResult(key = "jkl", docCount = 1, subAggregations = Map.empty)
435+
)
436+
)
437+
)
438+
)
439+
)
440+
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
441+
(docIdA, docA, docIdB, docB, docIdC, docC) =>
442+
for {
443+
_ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
444+
documentA = docA.copy(stringField = "abc", intField = 10)
445+
documentB = docB.copy(stringField = "def", intField = 20)
446+
documentC = docC.copy(stringField = "ghi", intField = 15)
447+
_ <- Executor.execute(ElasticRequest.upsert[TestDocument](firstSearchIndex, docIdA, documentA))
448+
_ <- Executor.execute(ElasticRequest.upsert[TestDocument](firstSearchIndex, docIdB, documentB))
449+
_ <- Executor.execute(
450+
ElasticRequest.upsert[TestDocument](firstSearchIndex, docIdC, documentC).refreshTrue
451+
)
452+
aggregation = samplerAggregation(
453+
"sampler_agg",
454+
sumAggregation("total_sum_field", TestDocument.intField)
455+
).withSubAgg(termsAggregation("string_categories", TestDocument.stringField.keyword))
456+
.maxDocumentsPerShard(100)
457+
aggsRes <-
458+
Executor
459+
.execute(ElasticRequest.aggregate(selectors = firstSearchIndex, aggregation = aggregation))
460+
.aggregations
461+
.map(_.head)
462+
463+
expectedResult =
464+
(
465+
"sampler_agg",
466+
SamplerAggregationResult(
467+
docCount = 3,
468+
subAggregations = Map(
469+
"total_sum_field" -> SumAggregationResult(value = 45.0),
470+
"string_categories" -> TermsAggregationResult(
471+
docErrorCount = 0,
472+
sumOtherDocCount = 0,
473+
buckets = Chunk(
474+
TermsAggregationBucketResult(key = "abc", docCount = 1, subAggregations = Map.empty),
475+
TermsAggregationBucketResult(key = "def", docCount = 1, subAggregations = Map.empty),
476+
TermsAggregationBucketResult(key = "ghi", docCount = 1, subAggregations = Map.empty)
477+
)
478+
)
479+
)
480+
)
481+
)
482+
} yield assert(aggsRes)(equalTo(expectedResult))
483+
}
484+
} @@ around(
485+
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
486+
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
487+
),
411488
test("aggregate using stats aggregation") {
412489
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
413490
(firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) =>
@@ -795,6 +872,57 @@ object HttpExecutorSpec extends IntegrationSpec {
795872
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
796873
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
797874
),
875+
test("search using sampler aggregation") {
876+
val expectedAggResult = SamplerAggregationResult(
877+
docCount = 2,
878+
subAggregations = Map(
879+
"sampled_strings" -> TermsAggregationResult(
880+
docErrorCount = 0,
881+
sumOtherDocCount = 0,
882+
buckets = Chunk(
883+
TermsAggregationBucketResult(key = "zio", docCount = 1, subAggregations = Map.empty),
884+
TermsAggregationBucketResult(key = "zio-elasticsearch", docCount = 1, subAggregations = Map.empty)
885+
)
886+
)
887+
)
888+
)
889+
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
890+
(docIdA, docA, docIdB, docB, docIdC, docC) =>
891+
val documentA = docA.copy(stringField = "zio")
892+
val documentB = docB.copy(stringField = "elasticsearch")
893+
val documentC = docC.copy(stringField = "zio-elasticsearch")
894+
val expectedSearchDocs = Chunk(documentA, documentC)
895+
for {
896+
_ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
897+
_ <- Executor.execute(ElasticRequest.upsert[TestDocument](firstSearchIndex, docIdA, documentA))
898+
_ <- Executor.execute(ElasticRequest.upsert[TestDocument](firstSearchIndex, docIdB, documentB))
899+
_ <- Executor.execute(
900+
ElasticRequest.upsert[TestDocument](firstSearchIndex, docIdC, documentC).refreshTrue
901+
)
902+
searchQuery = matches(TestDocument.stringField, "zio")
903+
aggregation = samplerAggregation(
904+
"sampler_agg",
905+
termsAggregation("sampled_strings", TestDocument.stringField.keyword)
906+
)
907+
.maxDocumentsPerShard(2)
908+
res <- Executor.execute(
909+
ElasticRequest
910+
.search(
911+
selectors = firstSearchIndex,
912+
query = searchQuery,
913+
aggregation = aggregation
914+
)
915+
)
916+
docs <- res.documentAs[TestDocument]
917+
samplerAgg <- res.aggregation("sampler_agg")
918+
} yield assert(docs.length)(equalTo(2)) &&
919+
assert(docs.toSet)(equalTo(expectedSearchDocs.toSet)) &&
920+
assert(samplerAgg)(isSome(equalTo(expectedAggResult)))
921+
}
922+
} @@ around(
923+
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
924+
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
925+
),
798926
test(
799927
"search using match all query with terms aggregations, nested max aggregation and nested bucketSelector aggregation"
800928
) {

modules/library/src/main/scala/zio/elasticsearch/ElasticAggregation.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,26 @@ object ElasticAggregation {
345345
final def percentilesAggregation(name: String, field: String): PercentilesAggregation =
346346
Percentiles(name = name, field = field, percents = Chunk.empty, missing = None)
347347

348+
/**
349+
* Constructs an instance of [[zio.elasticsearch.aggregation.SamplerAggregation]] using the specified parameters.
350+
*
351+
* @param name
352+
* the name of the aggregation
353+
* @param agg
354+
* the first required sub-aggregation to be included in the sampler
355+
* @param aggs
356+
* additional sub-aggregations to be included in the sampler
357+
* @return
358+
* an instance of [[zio.elasticsearch.aggregation.SamplerAggregation]] that represents sampler aggregation to be
359+
* performed. This aggregation has a default `shard_size` of `100` documents per shard.
360+
*/
361+
final def samplerAggregation(
362+
name: String,
363+
agg: SingleElasticAggregation,
364+
aggs: SingleElasticAggregation*
365+
): SamplerAggregation =
366+
Sampler(name = name, shardSizeValue = 100, subAggregations = agg +: aggs)
367+
348368
/**
349369
* Constructs a type-safe instance of [[zio.elasticsearch.aggregation.StatsAggregation]] using the specified
350370
* parameters.

modules/library/src/main/scala/zio/elasticsearch/aggregation/Aggregations.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,40 @@ private[elasticsearch] final case class Percentiles(
351351
}
352352
}
353353

354+
sealed trait SamplerAggregation extends SingleElasticAggregation with WithSubAgg[SamplerAggregation] {
355+
356+
/**
357+
* Sets the `shard_size` parameter for the [[zio.elasticsearch.aggregation.SamplerAggregation]]. This parameter
358+
* controls the maximum number of documents to be returned per shard.
359+
*
360+
* @param value
361+
* the maximum number of documents per shard
362+
* @return
363+
* an instance of the [[zio.elasticsearch.aggregation.SamplerAggregation]] enriched with the `shard_size` parameter.
364+
*/
365+
def maxDocumentsPerShard(value: Int): SamplerAggregation
366+
}
367+
368+
private[elasticsearch] final case class Sampler(
369+
name: String,
370+
shardSizeValue: Int,
371+
subAggregations: Seq[SingleElasticAggregation]
372+
) extends SamplerAggregation {
373+
self =>
374+
def maxDocumentsPerShard(value: Int): SamplerAggregation =
375+
self.copy(shardSizeValue = value)
376+
377+
def withSubAgg(aggregation: SingleElasticAggregation): SamplerAggregation =
378+
self.copy(subAggregations = aggregation +: subAggregations)
379+
380+
private[elasticsearch] def toJson: Json = {
381+
val samplerParamsContent: Obj = Obj("sampler" -> Obj("shard_size" -> shardSizeValue.toJson))
382+
val subAggsJson: Obj = Obj("aggs" -> subAggregations.map(_.toJson).reduce(_ merge _))
383+
384+
Obj(name -> (samplerParamsContent merge subAggsJson))
385+
}
386+
}
387+
354388
sealed trait StatsAggregation extends SingleElasticAggregation with HasMissing[StatsAggregation] with WithAgg
355389

356390
private[elasticsearch] final case class Stats(name: String, field: String, missing: Option[Double])

modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ object AggregationResponse {
7373
case FilterAggregationResponse(docCount, subAggregations) =>
7474
FilterAggregationResult(
7575
docCount = docCount,
76-
subAggregations = subAggregations.fold(Map[String, AggregationResult]())(_.map { case (key, response) =>
77-
(key, toResult(response))
78-
})
76+
subAggregations =
77+
subAggregations.map(_.map { case (key, response) => (key, toResult(response)) }).getOrElse(Map.empty)
7978
)
8079
case MaxAggregationResponse(value) =>
8180
MaxAggregationResult(value)
@@ -87,6 +86,11 @@ object AggregationResponse {
8786
PercentileRanksAggregationResult(values)
8887
case PercentilesAggregationResponse(values) =>
8988
PercentilesAggregationResult(values)
89+
case SamplerAggregationResponse(count, aggs) =>
90+
SamplerAggregationResult(
91+
docCount = count,
92+
subAggregations = aggs.map(_.map { case (key, response) => (key, toResult(response)) }).getOrElse(Map.empty)
93+
)
9094
case StatsAggregationResponse(count, min, max, avg, sum) =>
9195
StatsAggregationResult(count, min, max, avg, sum)
9296
case SumAggregationResponse(value) =>
@@ -99,9 +103,8 @@ object AggregationResponse {
99103
TermsAggregationBucketResult(
100104
docCount = b.docCount,
101105
key = b.key,
102-
subAggregations = b.subAggregations.fold(Map[String, AggregationResult]())(_.map { case (key, response) =>
103-
(key, toResult(response))
104-
})
106+
subAggregations =
107+
b.subAggregations.map(_.map { case (key, response) => (key, toResult(response)) }).getOrElse(Map.empty)
105108
)
106109
)
107110
)
@@ -169,6 +172,8 @@ private[elasticsearch] case class BucketDecoder(fields: Chunk[(String, Json)]) e
169172
)
170173
case str if str.contains("percentiles#") =>
171174
Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]))
175+
case str if str.contains("sampler#") =>
176+
Some(field -> data.unsafeAs[SamplerAggregationResponse](SamplerAggregationResponse.decoder))
172177
case str if str.contains("stats#") =>
173178
Some(
174179
field -> StatsAggregationResponse(
@@ -212,6 +217,8 @@ private[elasticsearch] case class BucketDecoder(fields: Chunk[(String, Json)]) e
212217
(field.split("#")(1), data.asInstanceOf[PercentileRanksAggregationResponse])
213218
case str if str.contains("percentiles#") =>
214219
(field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
220+
case str if str.contains("sampler#") =>
221+
(field.split("#")(1), data.asInstanceOf[SamplerAggregationResponse])
215222
case str if str.contains("stats#") =>
216223
(field.split("#")(1), data.asInstanceOf[StatsAggregationResponse])
217224
case str if str.contains("sum#") =>
@@ -320,6 +327,23 @@ private[elasticsearch] object PercentilesAggregationResponse {
320327
DeriveJsonDecoder.gen[PercentilesAggregationResponse]
321328
}
322329

330+
private[elasticsearch] final case class SamplerAggregationResponse(
331+
@jsonField("doc_count")
332+
docCount: Int,
333+
subAggregations: Option[Map[String, AggregationResponse]] = None
334+
) extends AggregationResponse
335+
336+
private[elasticsearch] object SamplerAggregationResponse {
337+
implicit val decoder: JsonDecoder[SamplerAggregationResponse] = Obj.decoder.mapOrFail { case Obj(fields) =>
338+
val bucketDecoder = BucketDecoder(fields)
339+
val allFields = bucketDecoder.allFields
340+
val docCount = allFields("doc_count").asInstanceOf[Int]
341+
val subAggs = bucketDecoder.subAggs
342+
343+
Right(SamplerAggregationResponse.apply(docCount, Option(subAggs).filter(_.nonEmpty)))
344+
}
345+
}
346+
323347
private[elasticsearch] final case class StatsAggregationResponse(
324348
count: Int,
325349
min: Double,

modules/library/src/main/scala/zio/elasticsearch/executor/response/SearchWithAggregationsResponse.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ private[elasticsearch] final case class SearchWithAggregationsResponse(
9292
PercentileRanksAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _)
9393
case str if str.contains("percentiles#") =>
9494
PercentilesAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _)
95+
case str if str.contains("sampler#") =>
96+
SamplerAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _)
9597
case str if str.contains("stats#") =>
9698
StatsAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _)
9799
case str if str.contains("sum#") =>

modules/library/src/main/scala/zio/elasticsearch/package.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,18 @@ package object elasticsearch extends IndexNameNewtype with IndexPatternNewtype w
156156
def asPercentilesAggregation(name: String): RIO[R, Option[PercentilesAggregationResult]] =
157157
aggregationAs[PercentilesAggregationResult](name)
158158

159+
/**
160+
* Executes the [[ElasticRequest.SearchRequest]] or the [[ElasticRequest.SearchAndAggregateRequest]].
161+
*
162+
* @param name
163+
* the name of the aggregation to retrieve
164+
* @return
165+
* a [[RIO]] effect that, when executed, will produce the aggregation as instance of
166+
* [[result.SamplerAggregationResult]].
167+
*/
168+
def asSamplerAggregation(name: String): RIO[R, Option[SamplerAggregationResult]] =
169+
aggregationAs[SamplerAggregationResult](name)
170+
159171
/**
160172
* Executes the [[ElasticRequest.SearchRequest]] or the [[ElasticRequest.SearchAndAggregateRequest]].
161173
*

modules/library/src/main/scala/zio/elasticsearch/result/AggregationResult.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,19 @@ final case class PercentileRanksAggregationResult private[elasticsearch] (values
6565
final case class PercentilesAggregationResult private[elasticsearch] (values: Map[String, Double])
6666
extends AggregationResult
6767

68+
final case class SamplerAggregationResult private[elasticsearch] (
69+
docCount: Int,
70+
subAggregations: Map[String, AggregationResult]
71+
) extends AggregationResult {
72+
73+
def subAggregationAs[A <: AggregationResult](aggName: String): Either[DecodingException, Option[A]] =
74+
subAggregations.get(aggName) match {
75+
case Some(agg: A) => Right(Some(agg))
76+
case Some(_) => Left(DecodingException(s"Aggregation with name $aggName was not of type you provided."))
77+
case None => Right(None)
78+
}
79+
}
80+
6881
final case class StatsAggregationResult private[elasticsearch] (
6982
count: Int,
7083
min: Double,

0 commit comments

Comments
 (0)