Skip to content

Commit 7315aec

Browse files
committed
feat: update web search code
Signed-off-by: yuluo-yx <yuluo08290126@gmail.com>
1 parent 08aa4c6 commit 7315aec

File tree

23 files changed

+594
-441
lines changed

23 files changed

+594
-441
lines changed

spring-ai-alibaba-integration-example/backend/src/main/java/com/alibaba/cloud/ai/application/service/SAAWebSearchService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ public Flux<String> chat(String prompt) {
8181
return chatClient.prompt()
8282
.advisors(
8383
createRetrievalAugmentationAdvisor(),
84-
reasoningContentAdvisor,
84+
// 不整合到 reasoning content 输出中
85+
// reasoningContentAdvisor,
8586
simpleLoggerAdvisor
8687
).user(prompt)
8788
.stream()

spring-ai-alibaba-integration-example/backend/src/main/java/com/alibaba/cloud/ai/application/websearch/config/WeSearchConfiguration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package com.alibaba.cloud.ai.application.websearch.config;
22

33
import com.alibaba.cloud.ai.application.websearch.rag.postretrieval.DashScopeDocumentRanker;
4+
import com.alibaba.cloud.ai.application.websearch.rag.preretrieval.query.expansion.MultiQueryExpander;
45
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
56
import com.alibaba.cloud.ai.model.RerankModel;
67

78
import org.springframework.ai.chat.client.ChatClient;
89
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
910
import org.springframework.ai.chat.prompt.PromptTemplate;
10-
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
1111
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
1212
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
1313
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;

spring-ai-alibaba-integration-example/backend/src/main/java/com/alibaba/cloud/ai/application/websearch/data/DataClean.java

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,27 @@ public List<Document> getData(GenericSearchResult respData) {
2929

3030
List<Document> documents = new ArrayList<>();
3131

32-
// 1. 获取 QueryContext 的 metadata
3332
Map<String, Object> metadata = getQueryMetadata(respData);
3433

3534
for (ScorePageItem pageItem : respData.getPageItems()) {
3635

37-
// 获取每个 pages 的 metadata
3836
Map<String, Object> pageItemMetadata = getPageItemMetadata(pageItem);
39-
// 获取 text
40-
String text = getText(pageItem);
41-
// 获取 media Document 限制,media 和 text 只能有一个
42-
// Media media = getMedia(pageItem);
43-
// 获取浏览器 score
4437
Double score = getScore(pageItem);
38+
String text = getText(pageItem);
39+
40+
if (Objects.equals("", text)) {
41+
42+
Media media = getMedia(pageItem);
43+
Document document = new Document.Builder()
44+
.metadata(metadata)
45+
.metadata(pageItemMetadata)
46+
.media(media)
47+
.score(score)
48+
.build();
49+
50+
documents.add(document);
51+
break;
52+
}
4553

4654
Document document = new Document.Builder()
4755
.metadata(metadata)
@@ -126,7 +134,18 @@ private Media getMedia(ScorePageItem pageItem) {
126134

127135
private String getText(ScorePageItem pageItem) {
128136

129-
return pageItem.getMainText();
137+
if (Objects.nonNull(pageItem.getMainText())) {
138+
139+
String mainText = pageItem.getMainText();
140+
141+
mainText = mainText.replaceAll("<[^>]+>", "");
142+
mainText = mainText.replaceAll("[\\n\\t\\r]+", " ");
143+
mainText = mainText.replaceAll("[\\u200B-\\u200D\\uFEFF]", "");
144+
145+
return mainText.trim();
146+
}
147+
148+
return "";
130149
}
131150

132151
public List<Document> limitResults(List<Document> documents, int minResults) {

spring-ai-alibaba-integration-example/backend/src/main/java/com/alibaba/cloud/ai/application/websearch/rag/WebSearchRetriever.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ public List<Document> retrieve(
4242

4343
// 搜索
4444
GenericSearchResult searchResp = searchEngine.search(query.text());
45-
logger.debug("search response: {}", searchResp);
4645

4746
// 清洗数据
4847
List<Document> cleanerData = dataCleaner.getData(searchResp);

spring-ai-alibaba-integration-example/backend/src/main/java/com/alibaba/cloud/ai/application/websearch/rag/join/ConcatenationDocumentJoiner.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
import java.util.Set;
1111
import java.util.stream.Collectors;
1212

13+
import org.jetbrains.annotations.NotNull;
1314
import org.slf4j.Logger;
1415
import org.slf4j.LoggerFactory;
1516

1617
import org.springframework.ai.document.Document;
1718
import org.springframework.ai.rag.Query;
1819
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
20+
import org.springframework.lang.Nullable;
1921
import org.springframework.util.Assert;
2022

2123
/**
@@ -27,8 +29,11 @@ public class ConcatenationDocumentJoiner implements DocumentJoiner {
2729

2830
private static final Logger logger = LoggerFactory.getLogger(ConcatenationDocumentJoiner.class);
2931

32+
@NotNull
3033
@Override
31-
public List<Document> join(Map<Query, List<List<Document>>> documentsForQuery) {
34+
public List<Document> join(
35+
@Nullable Map<Query, List<List<Document>>> documentsForQuery
36+
) {
3237

3338
Assert.notNull(documentsForQuery, "documentsForQuery cannot be null");
3439
Assert.noNullElements(documentsForQuery.keySet(), "documentsForQuery cannot contain null keys");

spring-ai-alibaba-integration-example/backend/src/main/java/com/alibaba/cloud/ai/application/websearch/rag/postretrieval/DashScopeDocumentRanker.java

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010
import com.alibaba.cloud.ai.model.RerankModel;
1111
import com.alibaba.cloud.ai.model.RerankRequest;
1212
import com.alibaba.cloud.ai.model.RerankResponse;
13+
import org.jetbrains.annotations.NotNull;
14+
import org.slf4j.Logger;
15+
import org.slf4j.LoggerFactory;
1316

1417
import org.springframework.ai.document.Document;
1518
import org.springframework.ai.rag.Query;
1619
import org.springframework.ai.rag.postretrieval.ranking.DocumentRanker;
20+
import org.springframework.lang.Nullable;
21+
import org.springframework.util.StringUtils;
1722

1823
/**
1924
* @author yuluo
@@ -22,14 +27,24 @@
2227

2328
public class DashScopeDocumentRanker implements DocumentRanker {
2429

30+
private static final Logger logger = LoggerFactory.getLogger(DashScopeDocumentRanker.class);
31+
2532
private final RerankModel rerankModel;
2633

2734
public DashScopeDocumentRanker(RerankModel rerankModel) {
2835
this.rerankModel = rerankModel;
2936
}
3037

38+
@NotNull
3139
@Override
32-
public List<Document> rank(Query query, List<Document> documents) {
40+
public List<Document> rank(
41+
@Nullable Query query,
42+
@Nullable List<Document> documents
43+
) {
44+
45+
if (documents == null || documents.isEmpty()) {
46+
return new ArrayList<>();
47+
}
3348

3449
try {
3550
List<Document> reorderDocs = new ArrayList<>();
@@ -39,27 +54,35 @@ public List<Document> rank(Query query, List<Document> documents) {
3954
.withTopN(documents.size())
4055
.build();
4156

42-
// 组装参数调用 rerankModel
43-
RerankRequest rerankRequest = new RerankRequest(
44-
query.text(),
45-
documents,
46-
rerankOptions
47-
);
48-
RerankResponse rerankResp = rerankModel.call(rerankRequest);
57+
if (Objects.nonNull(query) && StringUtils.hasText(query.text())) {
58+
// 组装参数调用 rerankModel
59+
RerankRequest rerankRequest = new RerankRequest(
60+
query.text(),
61+
documents,
62+
rerankOptions
63+
);
64+
RerankResponse rerankResp = rerankModel.call(rerankRequest);
4965

50-
rerankResp.getResults().forEach(res -> {
51-
Document outputDocs = res.getOutput();
66+
rerankResp.getResults().forEach(res -> {
67+
Document outputDocs = res.getOutput();
5268

53-
// 查找并添加到新的 list 中
54-
Optional<Document> foundDocsOptional = documents.stream()
55-
.filter(doc -> Objects.equals(doc.getId(), outputDocs.getId()))
56-
.findFirst();
69+
// 查找并添加到新的 list 中
70+
Optional<Document> foundDocsOptional = documents.stream()
71+
.filter(doc ->
72+
{
73+
// debug rerank output.
74+
logger.debug("doc id: {}, outputDocs id: {}", doc.getId(), outputDocs.getId());
75+
return Objects.equals(doc.getId(), outputDocs.getId());
76+
})
77+
.findFirst();
5778

58-
foundDocsOptional.ifPresent(reorderDocs::add);
59-
});
79+
foundDocsOptional.ifPresent(reorderDocs::add);
80+
});
81+
}
6082

6183
return reorderDocs;
62-
} catch (Exception e) {
84+
}
85+
catch (Exception e) {
6386
// 根据异常类型做进一步处理
6487
throw new SAAAppException(e.getMessage());
6588
}

spring-ai-alibaba-integration-example/backend/src/main/java/com/alibaba/cloud/ai/application/websearch/rag/preretrieval/query/expansion/MultiQueryExpander.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import java.util.Objects;
66
import java.util.stream.Collectors;
77

8+
import org.jetbrains.annotations.NotNull;
89
import org.slf4j.Logger;
910
import org.slf4j.LoggerFactory;
1011

@@ -45,7 +46,7 @@ public class MultiQueryExpander implements QueryExpander {
4546

4647
private static final Boolean DEFAULT_INCLUDE_ORIGINAL = true;
4748

48-
private static final Integer DEFAULT_NUMBER_OF_QUERIES = 5;
49+
private static final Integer DEFAULT_NUMBER_OF_QUERIES = 3;
4950

5051
private final ChatClient chatClient;
5152

@@ -61,6 +62,7 @@ public MultiQueryExpander(
6162
@Nullable Boolean includeOriginal,
6263
@Nullable Integer numberOfQueries
6364
) {
65+
6466
Assert.notNull(chatClientBuilder, "ChatClient.Builder must not be null");
6567

6668
this.chatClient = chatClientBuilder.build();
@@ -71,8 +73,9 @@ public MultiQueryExpander(
7173
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query");
7274
}
7375

76+
@NotNull
7477
@Override
75-
public List<Query> expand(Query query) {
78+
public List<Query> expand(@Nullable Query query) {
7679

7780
Assert.notNull(query, "Query must not be null");
7881

@@ -95,7 +98,9 @@ public List<Query> expand(Query query) {
9598

9699
if (CollectionUtils.isEmpty(queryVariants) || this.numberOfQueries != queryVariants.size()) {
97100

98-
logger.warn("Query expansion result dose not contain the requested {} variants for query: {}. is return.", this.numberOfQueries, query.text());
101+
logger.warn("Query expansion result dose not contain the requested {} variants for query: {}. is return.",
102+
this.numberOfQueries, query.text());
103+
99104
return List.of(query);
100105
}
101106

@@ -109,10 +114,16 @@ public List<Query> expand(Query query) {
109114
logger.debug("Including original query in the expanded queries for query: {}", query.text());
110115
queries.add(0, query);
111116
}
117+
118+
logger.debug("Rewrite queries: {}", queries);
112119

113120
return queries;
114121
}
115122

123+
public static Builder builder() {
124+
return new Builder();
125+
}
126+
116127
public static final class Builder {
117128

118129
private ChatClient.Builder chatClientBuilder;

spring-ai-alibaba-integration-example/backend/src/main/java/com/alibaba/cloud/ai/application/websearch/rag/prompt/CustomContextQueryAugmenter.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import java.util.concurrent.atomic.AtomicInteger;
66
import java.util.stream.Collectors;
77

8+
import org.jetbrains.annotations.NotNull;
89
import org.slf4j.Logger;
910
import org.slf4j.LoggerFactory;
1011

@@ -71,18 +72,25 @@ public CustomContextQueryAugmenter(
7172
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "query", "context");
7273
}
7374

75+
@NotNull
7476
@Override
75-
public Query augment(Query query, List<Document> documents) {
77+
public Query augment(
78+
@Nullable Query query,
79+
@Nullable List<Document> documents
80+
) {
7681

7782
Assert.notNull(query, "Query must not be null");
7883
Assert.notNull(documents, "Documents must not be null");
7984

8085
logger.debug("Augmenting query: {}", query);
8186

8287
if (documents.isEmpty()) {
88+
logger.debug("No documents found. Augmenting query with empty context.");
8389
return augmentQueryWhenEmptyContext(query);
8490
}
8591

92+
logger.debug("Documents found. Augmenting query with context.");
93+
8694
// 1. collect content from documents.
8795
AtomicInteger idCounter = new AtomicInteger(1);
8896
String documentContext = documents.stream()

spring-ai-alibaba-integration-example/backend/src/main/resources/application-dev.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ spring:
4242
# enabled debug log out.
4343
logging:
4444
level:
45-
org.springframework.ai: debug
46-
com.alibaba.dashscope.api: debug
45+
# org.springframework.ai: debug
46+
# com.alibaba.dashscope.api: debug
4747
com.alibaba.cloud.ai.application.websearch: debug
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
http://localhost:8080/api/v1/search?query="杭州有什么推荐旅游的地方吗?"

0 commit comments

Comments
 (0)