[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(improvement)(chat) Obtain similar query from ExemplarService instead of directly from embedding store #1278

Merged
merged 1 commit into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
(improvement)(chat) Obtain similar query from ExemplarService instead of directly from embedding store
  • Loading branch information
lxwcodemonkey committed Jun 29, 2024
commit 3c51b6784b7a495b9510529e28dfa1916921d872
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
import com.baomidou.mybatisplus.annotation.TableName;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;

import java.util.Date;

@Data
@Builder
@ToString
@AllArgsConstructor
@NoArgsConstructor
@TableName("s2_chat_memory")
public class ChatMemoryDO {
@TableId(type = IdType.AUTO)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,18 @@

import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.PageInfo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.SimilarQueryManager;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

Expand All @@ -44,25 +39,10 @@ private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
}

// NOTE(review): this span is a rendered diff hunk (header "-44,25 +39,10") whose +/- markers were
// stripped, so two bodies of getSimilarQueries appear back to back and the text is not compilable
// as shown. The line counts in the hunk header suggest the first body (steps 1 and 2) is the
// removed code and the ExemplarService body is the added code -- confirm against the repository.
/**
 * Returns queries similar to {@code queryText} for the given agent.
 *
 * @param queryText text of the query to find similar queries for
 * @param agentId   agent identifier; the ExemplarService variant calls {@code agentId.toString()},
 *                  so a null value would throw a NullPointerException there
 * @return list of {@link SimilarQueryRecallResp}; the earlier variant returns an empty list when
 *         nothing is recalled or when no chat query records are found
 */
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
// ---- Variant A (presumably the pre-change implementation) ----
//1. recall solved query by queryText
SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
List<SimilarQueryRecallResp> similarQueries = solvedQueryManager.recallSimilarQuery(queryText, agentId);
if (CollectionUtils.isEmpty(similarQueries)) {
return Lists.newArrayList();
}
//2. remove low score query
// Collect the ids of the recalled queries so their stored records can be looked up.
List<Long> queryIds = similarQueries.stream()
.map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList());
// Records with a non-null score at or below this value are treated as low score.
int lowScoreThreshold = 3;
List<QueryResp> queryResps = getChatQuery(queryIds);
if (CollectionUtils.isEmpty(queryResps)) {
return Lists.newArrayList();
}
// Records without a score are never added to the low-score set.
Set<Long> lowScoreQueryIds = queryResps.stream().filter(queryResp ->
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
// Keep only recalled queries whose id is not in the low-score set.
return similarQueries.stream().filter(solvedQueryRecallResp ->
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
// ---- Variant B (presumably the post-change implementation) ----
// Recall exemplars through ExemplarService, keyed by the agent id rendered as a string.
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
// NOTE(review): the literal 5 is presumably the maximum number of exemplars to recall -- confirm
// against ExemplarService.recallExemplars.
List<SqlExemplar> exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5);
// Only the exemplar's question is copied into the response (as queryText); no other field is set here.
return exemplars.stream().map(sqlExemplar ->
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
.collect(Collectors.toList());
}

Expand All @@ -71,16 +51,6 @@ private ChatQueryDO getChatQuery(Long queryId) {
return chatQueryRepository.getChatQueryDO(queryId);
}

/**
 * Looks up chat query records restricted to the given ids.
 *
 * <p>Only the first page is requested, with a page size of 100, so at most one page of
 * results is returned by this call.
 *
 * @param queryIds ids passed to the page request as its id filter
 * @return the list carried by the page result returned from the repository
 */
private List<QueryResp> getChatQuery(List<Long> queryIds) {
    // Build a single-page request: page 1, 100 rows, filtered by the supplied ids.
    PageQueryInfoReq idFilterRequest = new PageQueryInfoReq();
    idFilterRequest.setCurrent(1);
    idFilterRequest.setPageSize(100);
    idFilterRequest.setIds(queryIds);

    // The repository is resolved from the application context rather than injected.
    ChatQueryRepository repository = ContextUtils.getBean(ChatQueryRepository.class);
    // NOTE(review): the second argument is passed as null; its meaning is not visible here.
    return repository.getChatQuery(idFilterRequest, null).getList();
}

private void updateChatQuery(ChatQueryDO chatQueryDO) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
UpdateWrapper<ChatQueryDO> updateWrapper = new UpdateWrapper<>();
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@ public class EmbeddingConfig {
// Bound to property "s2.embedding.nResult"; the placeholder supplies 1 as the default value.
@Value("${s2.embedding.nResult:1}")
private int nResult;

// NOTE(review): per the hunk header "-16,12 +16,6" the two solved-query fields below appear to be
// the lines removed by this commit -- confirm against the repository.
// Bound to property "s2.embedding.solved.query.collection"; default "solved_query_collection".
@Value("${s2.embedding.solved.query.collection:solved_query_collection}")
private String solvedQueryCollection;

// Bound to property "s2.embedding.solved.query.nResult"; default 5.
@Value("${s2.embedding.solved.query.nResult:5}")
private int solvedQueryResultNum;

// Bound to property "s2.embedding.metric.analyzeQuery.collection".
// NOTE(review): the default is "solved_query_collection", the same default as the solved-query
// collection above -- confirm this sharing is intended.
@Value("${s2.embedding.metric.analyzeQuery.collection:solved_query_collection}")
private String metricAnalyzeQueryCollection;

Expand Down
Loading