diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 9fe134ed6..3ead33fb1 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -87,6 +87,7 @@ import ( config2 "github.com/apache/answer/internal/service/config" "github.com/apache/answer/internal/service/content" "github.com/apache/answer/internal/service/dashboard" + "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/eventqueue" export2 "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/feature_toggle" @@ -283,7 +284,8 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, apiKeyService := apikey.NewAPIKeyService(apiKeyRepo) adminAPIKeyController := controller_admin.NewAdminAPIKeyController(apiKeyService) featureToggleService := feature_toggle.NewFeatureToggleService(siteInfoRepo) - mcpController := controller.NewMCPController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, featureToggleService) + embeddingService := embedding.NewEmbeddingService() + mcpController := controller.NewMCPController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, featureToggleService, embeddingService) aiConversationRepo := ai_conversation.NewAIConversationRepo(dataData) aiConversationService := ai_conversation2.NewAIConversationService(aiConversationRepo, userCommon) aiController := controller.NewAIController(searchService, siteInfoCommonService, tagCommonService, questionCommon, commentRepo, userCommon, answerRepo, mcpController, aiConversationService, featureToggleService) diff --git a/internal/base/constant/ai_config.go b/internal/base/constant/ai_config.go index aa733bbaf..a25e47a45 100644 --- a/internal/base/constant/ai_config.go +++ b/internal/base/constant/ai_config.go @@ -33,6 +33,7 @@ const ( - get_tags: 搜索标签信息 - get_tag_detail: 获取特定标签的详细信息 - get_user: 搜索用户信息 +- semantic_search: 通过语义相似度搜索问题和答案。当用户的问题与现有内容概念相关但可能不匹配确切关键词时使用此工具。当 get_questions 关键词搜索返回较差结果时,请使用 semantic_search。 请根据用户的问题智能地使用这些工具来提供准确的答案。如果需要查询系统信息,请先使用相应的工具获取数据。` DefaultAIPromptConfigEnUS = `You are an intelligent assistant that can help users query information in the system. User question: %s @@ -44,6 +45,7 @@ You can use the following tools to query system information: - get_tags: Search for tag information - get_tag_detail: Get detailed information about a specific tag - get_user: Search for user information +- semantic_search: Search questions and answers by semantic meaning. Use this when the user's question relates conceptually to existing content but may not match exact keywords. When get_questions keyword search returns poor results, use semantic_search instead. Please intelligently use these tools based on the user's question to provide accurate answers. If you need to query system information, please use the appropriate tools to get the data first.` ) diff --git a/internal/controller/ai_controller.go b/internal/controller/ai_controller.go index e020ed30e..125cdab22 100644 --- a/internal/controller/ai_controller.go +++ b/internal/controller/ai_controller.go @@ -446,6 +446,7 @@ func (c *AIController) handleAIConversation(ctx *gin.Context, w http.ResponseWri toolCalls, newMessages, finished, aiResponse := c.processAIStream(ctx, w, id, conversationCtx.Model, client, aiReq, messages) messages = newMessages + log.Debugf("Round %d: toolCalls=%v", round+1, toolCalls) if aiResponse != "" { conversationCtx.Messages = append(conversationCtx.Messages, &ai_conversation.ConversationMessage{ Role: "assistant", @@ -497,6 +498,10 @@ func (c *AIController) processAIStream( break } + if len(response.Choices) == 0 { + continue + } + choice := response.Choices[0] if len(choice.Delta.ToolCalls) > 0 { @@ -735,6 +740,8 @@ func (c *AIController) callMCPTool(ctx context.Context, toolName string, argumen result, err = c.mcpController.MCPTagDetailsHandler()(ctx, request) case "get_user": result, err = c.mcpController.MCPUserDetailsHandler()(ctx, request) + case "semantic_search": + result, err = c.mcpController.MCPSemanticSearchHandler()(ctx, request) default: return "", fmt.Errorf("unknown tool: %s", toolName) } diff --git a/internal/controller/mcp_controller.go b/internal/controller/mcp_controller.go index d52f57979..b40c58cf4 100644 --- a/internal/controller/mcp_controller.go +++ b/internal/controller/mcp_controller.go @@ -31,11 +31,13 @@ import ( answercommon "github.com/apache/answer/internal/service/answer_common" "github.com/apache/answer/internal/service/comment" "github.com/apache/answer/internal/service/content" + "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/feature_toggle" questioncommon "github.com/apache/answer/internal/service/question_common" "github.com/apache/answer/internal/service/siteinfo_common" tagcommonser "github.com/apache/answer/internal/service/tag_common" usercommon "github.com/apache/answer/internal/service/user_common" + "github.com/apache/answer/plugin" "github.com/mark3labs/mcp-go/mcp" "github.com/segmentfault/pacman/log" ) @@ -49,6 +51,7 @@ type MCPController struct { userCommon *usercommon.UserCommon answerRepo answercommon.AnswerRepo featureToggleSvc *feature_toggle.FeatureToggleService + embeddingService *embedding.EmbeddingService } // NewMCPController new site info controller. @@ -61,6 +64,7 @@ func NewMCPController( userCommon *usercommon.UserCommon, answerRepo answercommon.AnswerRepo, featureToggleSvc *feature_toggle.FeatureToggleService, + embeddingService *embedding.EmbeddingService, ) *MCPController { return &MCPController{ searchService: searchService, @@ -71,6 +75,7 @@ func NewMCPController( userCommon: userCommon, answerRepo: answerRepo, featureToggleSvc: featureToggleSvc, + embeddingService: embeddingService, } } @@ -349,3 +354,131 @@ func (c *MCPController) MCPUserDetailsHandler() func(ctx context.Context, reques return mcp.NewToolResultText(string(res)), nil } } + +func (c *MCPController) MCPSemanticSearchHandler() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := c.ensureMCPEnabled(ctx); err != nil { + return nil, err + } + cond := schema.NewMCPSemanticSearchCond(request) + if len(cond.Query) == 0 { + return mcp.NewToolResultText("Query is required for semantic search."), nil + } + + siteGeneral, err := c.siteInfoService.GetSiteGeneral(ctx) + if err != nil { + log.Errorf("get site general info failed: %v", err) + return nil, err + } + + results, err := c.embeddingService.SearchSimilar(ctx, cond.Query, cond.TopK) + if err != nil { + log.Errorf("semantic search failed: %v", err) + return mcp.NewToolResultText("Semantic search is not available. Embedding may not be configured."), nil + } + if len(results) == 0 { + return mcp.NewToolResultText("No semantically similar content found."), nil + } + + resp := make([]*schema.MCPSemanticSearchResp, 0, len(results)) + for _, r := range results { + var meta plugin.VectorSearchMetadata + _ = json.Unmarshal([]byte(r.Metadata), &meta) + + item := &schema.MCPSemanticSearchResp{ + ObjectID: r.ObjectID, + ObjectType: r.ObjectType, + Score: r.Score, + } + + // Compose link from metadata + if r.ObjectType == "answer" && meta.AnswerID != "" { + item.Link = fmt.Sprintf("%s/questions/%s/%s", siteGeneral.SiteUrl, meta.QuestionID, meta.AnswerID) + } else { + item.Link = fmt.Sprintf("%s/questions/%s", siteGeneral.SiteUrl, meta.QuestionID) + } + + // Query content from DB using IDs stored in metadata + if r.ObjectType == "question" { + question, qErr := c.questioncommon.Info(ctx, meta.QuestionID, "") + if qErr != nil { + log.Warnf("get question %s for semantic search failed: %v", meta.QuestionID, qErr) + } else { + item.Title = question.Title + item.Content = question.Content + } + + // Fetch answers by ID from metadata + for _, a := range meta.Answers { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, a.AnswerID) + if aErr != nil || !exist { + continue + } + answerItem := &schema.MCPSemanticSearchAnswer{ + AnswerID: a.AnswerID, + Content: answerEntity.OriginalText, + } + // Fetch comments on this answer from DB + for _, ac := range a.Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, ac.CommentID) + if cErr == nil && cExist { + answerItem.Comments = append(answerItem.Comments, &schema.MCPSemanticSearchComment{ + CommentID: ac.CommentID, + Content: cmt.OriginalText, + }) + } + } + item.Answers = append(item.Answers, answerItem) + } + + // Fetch question comments from DB + for _, qc := range meta.Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, qc.CommentID) + if cErr == nil && cExist { + item.Comments = append(item.Comments, &schema.MCPSemanticSearchComment{ + CommentID: qc.CommentID, + Content: cmt.OriginalText, + }) + } + } + } else if r.ObjectType == "answer" { + // Fetch question title for context + question, qErr := c.questioncommon.Info(ctx, meta.QuestionID, "") + if qErr == nil { + item.Title = question.Title + } + + // Fetch answer content from DB + if meta.AnswerID != "" { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, meta.AnswerID) + if aErr == nil && exist { + item.Content = answerEntity.OriginalText + } + } else if len(meta.Answers) > 0 { + answerEntity, exist, aErr := c.answerRepo.GetAnswer(ctx, meta.Answers[0].AnswerID) + if aErr == nil && exist { + item.Content = answerEntity.OriginalText + } + } + + // Fetch answer comments from DB + if len(meta.Answers) > 0 { + for _, ac := range meta.Answers[0].Comments { + cmt, cExist, cErr := c.commentRepo.GetComment(ctx, ac.CommentID) + if cErr == nil && cExist { + item.Comments = append(item.Comments, &schema.MCPSemanticSearchComment{ + CommentID: ac.CommentID, + Content: cmt.OriginalText, + }) + } + } + } + } + + resp = append(resp, item) + } + + data, _ := json.Marshal(resp) + return mcp.NewToolResultText(string(data)), nil + } +} diff --git a/internal/repo/vector_search_sync/syncer.go b/internal/repo/vector_search_sync/syncer.go new file mode 100644 index 000000000..f27b1fd2c --- /dev/null +++ b/internal/repo/vector_search_sync/syncer.go @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package vector_search_sync + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/apache/answer/internal/base/data" + "github.com/apache/answer/internal/entity" + "github.com/apache/answer/pkg/uid" + "github.com/apache/answer/plugin" + "github.com/segmentfault/pacman/log" +) + +// NewPluginSyncer creates a new VectorSearchSyncer that reads from the database. +func NewPluginSyncer(data *data.Data) plugin.VectorSearchSyncer { + return &PluginSyncer{data: data} +} + +// PluginSyncer implements plugin.VectorSearchSyncer. +// It aggregates question/answer text with comments for vector embedding. +type PluginSyncer struct { + data *data.Data +} + +// GetQuestionsPage returns a page of questions with aggregated text +// (question title + body + all answers + all comments). +func (p *PluginSyncer) GetQuestionsPage(ctx context.Context, page, pageSize int) ( + []*plugin.VectorSearchContent, error) { + questions := make([]*entity.Question, 0) + startNum := (page - 1) * pageSize + err := p.data.DB.Context(ctx).Limit(pageSize, startNum).Find(&questions) + if err != nil { + return nil, err + } + return p.buildQuestionContents(ctx, questions) +} + +// GetAnswersPage returns a page of answers with aggregated text +// (parent question title + answer body + answer comments). +func (p *PluginSyncer) GetAnswersPage(ctx context.Context, page, pageSize int) ( + []*plugin.VectorSearchContent, error) { + answers := make([]*entity.Answer, 0) + startNum := (page - 1) * pageSize + err := p.data.DB.Context(ctx).Limit(pageSize, startNum).Find(&answers) + if err != nil { + return nil, err + } + return p.buildAnswerContents(ctx, answers) +} + +// buildQuestionContents aggregates each question with its answers and comments. +func (p *PluginSyncer) buildQuestionContents(ctx context.Context, questions []*entity.Question) ( + []*plugin.VectorSearchContent, error) { + result := make([]*plugin.VectorSearchContent, 0, len(questions)) + for _, q := range questions { + meta := plugin.VectorSearchMetadata{ + QuestionID: uid.DeShortID(q.ID), + } + + var parts []string + parts = append(parts, fmt.Sprintf("Question: %s\n%s", q.Title, q.OriginalText)) + + // Get answers for this question + answers := make([]*entity.Answer, 0) + err := p.data.DB.Context(ctx).Where("question_id = ?", q.ID).Find(&answers) + if err != nil { + log.Warnf("get answers for question %s failed: %v", q.ID, err) + } else { + for _, a := range answers { + parts = append(parts, fmt.Sprintf("Answer: %s", a.OriginalText)) + answerMeta := plugin.VectorSearchMetadataAnswer{ + AnswerID: uid.DeShortID(a.ID), + } + + // Get comments on this answer + answerComments := make([]*entity.Comment, 0) + err := p.data.DB.Context(ctx).Where("object_id = ?", a.ID). + OrderBy("created_at ASC").Limit(50).Find(&answerComments) + if err != nil { + log.Warnf("get comments for answer %s failed: %v", a.ID, err) + } else { + for _, c := range answerComments { + parts = append(parts, fmt.Sprintf("Comment on answer: %s", c.OriginalText)) + answerMeta.Comments = append(answerMeta.Comments, plugin.VectorSearchMetadataComment{ + CommentID: uid.DeShortID(c.ID), + }) + } + } + meta.Answers = append(meta.Answers, answerMeta) + } + } + + // Get comments on the question + questionComments := make([]*entity.Comment, 0) + err = p.data.DB.Context(ctx).Where("object_id = ?", q.ID). + OrderBy("created_at ASC").Limit(50).Find(&questionComments) + if err != nil { + log.Warnf("get comments for question %s failed: %v", q.ID, err) + } else { + for _, c := range questionComments { + parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) + meta.Comments = append(meta.Comments, plugin.VectorSearchMetadataComment{ + CommentID: uid.DeShortID(c.ID), + }) + } + } + + metaJSON, _ := json.Marshal(meta) + result = append(result, &plugin.VectorSearchContent{ + ObjectID: uid.DeShortID(q.ID), + ObjectType: "question", + Title: q.Title, + Content: strings.Join(parts, "\n\n"), + Metadata: string(metaJSON), + }) + } + return result, nil +} + +// buildAnswerContents aggregates each answer with its parent question title and comments. +func (p *PluginSyncer) buildAnswerContents(ctx context.Context, answers []*entity.Answer) ( + []*plugin.VectorSearchContent, error) { + result := make([]*plugin.VectorSearchContent, 0, len(answers)) + for _, a := range answers { + // Get parent question for title + question := &entity.Question{} + exist, err := p.data.DB.Context(ctx).Where("id = ?", a.QuestionID).Get(question) + if err != nil { + log.Errorf("get question %s failed: %v", a.QuestionID, err) + continue + } + if !exist { + continue + } + + meta := plugin.VectorSearchMetadata{ + QuestionID: uid.DeShortID(a.QuestionID), + AnswerID: uid.DeShortID(a.ID), + } + + var parts []string + parts = append(parts, fmt.Sprintf("Question: %s", question.Title)) + parts = append(parts, fmt.Sprintf("Answer: %s", a.OriginalText)) + + answerMeta := plugin.VectorSearchMetadataAnswer{ + AnswerID: uid.DeShortID(a.ID), + } + + // Get comments on this answer + answerComments := make([]*entity.Comment, 0) + err = p.data.DB.Context(ctx).Where("object_id = ?", a.ID). + OrderBy("created_at ASC").Limit(50).Find(&answerComments) + if err != nil { + log.Warnf("get comments for answer %s failed: %v", a.ID, err) + } else { + for _, c := range answerComments { + parts = append(parts, fmt.Sprintf("Comment: %s", c.OriginalText)) + answerMeta.Comments = append(answerMeta.Comments, plugin.VectorSearchMetadataComment{ + CommentID: uid.DeShortID(c.ID), + }) + } + } + meta.Answers = append(meta.Answers, answerMeta) + + metaJSON, _ := json.Marshal(meta) + result = append(result, &plugin.VectorSearchContent{ + ObjectID: uid.DeShortID(a.ID), + ObjectType: "answer", + Title: question.Title, + Content: strings.Join(parts, "\n\n"), + Metadata: string(metaJSON), + }) + } + return result, nil +} diff --git a/internal/schema/mcp_schema.go b/internal/schema/mcp_schema.go index bead21c9d..9afee72ec 100644 --- a/internal/schema/mcp_schema.go +++ b/internal/schema/mcp_schema.go @@ -27,15 +27,17 @@ import ( ) const ( - MCPSearchCondKeyword = "keyword" - MCPSearchCondUsername = "username" - MCPSearchCondScore = "score" - MCPSearchCondTag = "tag" - MCPSearchCondPage = "page" - MCPSearchCondPageSize = "page_size" - MCPSearchCondTagName = "tag_name" - MCPSearchCondQuestionID = "question_id" - MCPSearchCondObjectID = "object_id" + MCPSearchCondKeyword = "keyword" + MCPSearchCondUsername = "username" + MCPSearchCondScore = "score" + MCPSearchCondTag = "tag" + MCPSearchCondPage = "page" + MCPSearchCondPageSize = "page_size" + MCPSearchCondTagName = "tag_name" + MCPSearchCondQuestionID = "question_id" + MCPSearchCondObjectID = "object_id" + MCPSearchCondSemanticQuery = "query" + MCPSearchCondTopK = "top_k" ) type MCPSearchCond struct { @@ -98,6 +100,48 @@ type MCPSearchCommentInfoResp struct { Link string `json:"link"` } +// MCPSemanticSearchCond is the condition for semantic search. +type MCPSemanticSearchCond struct { + Query string `json:"query"` + TopK int `json:"top_k"` +} + +// MCPSemanticSearchResp is a single semantic search result. +type MCPSemanticSearchResp struct { + ObjectID string `json:"object_id"` + ObjectType string `json:"object_type"` + Title string `json:"title"` + Content string `json:"content"` + Score float64 `json:"score"` + Link string `json:"link"` + Answers []*MCPSemanticSearchAnswer `json:"answers,omitempty"` + Comments []*MCPSemanticSearchComment `json:"comments,omitempty"` +} + +// MCPSemanticSearchAnswer is an answer in a semantic search result. +type MCPSemanticSearchAnswer struct { + AnswerID string `json:"answer_id"` + Content string `json:"content"` + Comments []*MCPSemanticSearchComment `json:"comments,omitempty"` +} + +// MCPSemanticSearchComment is a comment in a semantic search result. +type MCPSemanticSearchComment struct { + CommentID string `json:"comment_id"` + Content string `json:"content"` +} + +func NewMCPSemanticSearchCond(request mcp.CallToolRequest) *MCPSemanticSearchCond { + cond := &MCPSemanticSearchCond{TopK: 5} + if query, ok := getRequestValue(request, MCPSearchCondSemanticQuery); ok { + cond.Query = query + } + if topK, ok := getRequestNumber(request, MCPSearchCondTopK); ok && topK > 0 { + cond.TopK = topK + } + return cond +} + func NewMCPSearchCond(request mcp.CallToolRequest) *MCPSearchCond { cond := &MCPSearchCond{} if keyword, ok := getRequestValue(request, MCPSearchCondKeyword); ok { diff --git a/internal/schema/mcp_tools/mcp_tools.go b/internal/schema/mcp_tools/mcp_tools.go index 949a738c7..3ae6b3bea 100644 --- a/internal/schema/mcp_tools/mcp_tools.go +++ b/internal/schema/mcp_tools/mcp_tools.go @@ -32,6 +32,7 @@ var ( NewTagsTool(), NewTagDetailTool(), NewUserTool(), + NewSemanticSearchTool(), } ) @@ -103,3 +104,17 @@ func NewUserTool() mcp.Tool { ) return listFilesTool } + +func NewSemanticSearchTool() mcp.Tool { + tool := mcp.NewTool("semantic_search", + mcp.WithDescription("Search questions and answers by semantic meaning. Use this when the user's question relates conceptually to existing content but may not match exact keywords. Returns the most semantically similar content."), + mcp.WithString(schema.MCPSearchCondSemanticQuery, + mcp.Required(), + mcp.Description("The search query text to find semantically similar questions and answers"), + ), + mcp.WithNumber(schema.MCPSearchCondTopK, + mcp.Description("Maximum number of results to return (default 5)"), + ), + ) + return tool +} diff --git a/internal/service/embedding/embedding_service.go b/internal/service/embedding/embedding_service.go new file mode 100644 index 000000000..c69d60d8e --- /dev/null +++ b/internal/service/embedding/embedding_service.go @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package embedding + +import ( + "context" + "fmt" + + "github.com/apache/answer/plugin" +) + +// EmbeddingService is a thin facade that delegates semantic search to a VectorSearch plugin. +// If no plugin is enabled, semantic search is unavailable. +type EmbeddingService struct{} + +// NewEmbeddingService creates a new EmbeddingService. +func NewEmbeddingService() *EmbeddingService { + return &EmbeddingService{} +} + +// SearchSimilar delegates to the VectorSearch plugin. +// Returns an error if no plugin is enabled. +func (s *EmbeddingService) SearchSimilar(ctx context.Context, query string, topK int) ([]plugin.VectorSearchResult, error) { + var results []plugin.VectorSearchResult + var searchErr error + found := false + + err := plugin.CallVectorSearch(func(vs plugin.VectorSearch) error { + found = true + results, searchErr = vs.SearchSimilar(ctx, query, topK) + return nil + }) + if err != nil { + return nil, fmt.Errorf("call vector search plugin failed: %w", err) + } + if !found { + return nil, fmt.Errorf("semantic search is not available: no vector search plugin is enabled") + } + if searchErr != nil { + return nil, searchErr + } + return results, nil +} diff --git a/internal/service/plugin_common/plugin_common_service.go b/internal/service/plugin_common/plugin_common_service.go index eb46b5ac2..b3674a230 100644 --- a/internal/service/plugin_common/plugin_common_service.go +++ b/internal/service/plugin_common/plugin_common_service.go @@ -25,6 +25,7 @@ import ( "github.com/apache/answer/internal/base/data" "github.com/apache/answer/internal/repo/search_sync" + "github.com/apache/answer/internal/repo/vector_search_sync" "github.com/segmentfault/pacman/errors" "github.com/segmentfault/pacman/log" @@ -103,6 +104,12 @@ func (ps *PluginCommonService) UpdatePluginConfig(ctx context.Context, req *sche } return nil }) + _ = plugin.CallVectorSearch(func(vs plugin.VectorSearch) error { + if vs.Info().SlugName == req.PluginSlugName { + vs.RegisterSyncer(ctx, vector_search_sync.NewPluginSyncer(ps.data)) + } + return nil + }) _ = plugin.CallImporter(func(importer plugin.Importer) error { importer.RegisterImporterFunc(ctx, ps.importerService.NewImporterFunc()) return nil @@ -176,6 +183,16 @@ func (ps *PluginCommonService) initPluginData() { }) } + // register syncers for search and vector search plugins on startup + _ = plugin.CallSearch(func(search plugin.Search) error { + search.RegisterSyncer(context.Background(), search_sync.NewPluginSyncer(ps.data)) + return nil + }) + _ = plugin.CallVectorSearch(func(vs plugin.VectorSearch) error { + vs.RegisterSyncer(context.Background(), vector_search_sync.NewPluginSyncer(ps.data)) + return nil + }) + // init plugin user config plugin.RegisterGetPluginUserConfigFunc(func(userID, pluginSlugName string) []byte { pluginUserConfig, exist, err := ps.pluginUserConfigRepo.GetPluginUserConfig(context.Background(), userID, pluginSlugName) diff --git a/internal/service/provider.go b/internal/service/provider.go index 3e43b0ae0..26f1c4309 100644 --- a/internal/service/provider.go +++ b/internal/service/provider.go @@ -36,6 +36,7 @@ import ( "github.com/apache/answer/internal/service/config" "github.com/apache/answer/internal/service/content" "github.com/apache/answer/internal/service/dashboard" + "github.com/apache/answer/internal/service/embedding" "github.com/apache/answer/internal/service/eventqueue" "github.com/apache/answer/internal/service/export" "github.com/apache/answer/internal/service/feature_toggle" @@ -134,4 +135,5 @@ var ProviderSetService = wire.NewSet( apikey.NewAPIKeyService, ai_conversation.NewAIConversationService, feature_toggle.NewFeatureToggleService, + embedding.NewEmbeddingService, ) diff --git a/internal/service/siteinfo/siteinfo_service.go b/internal/service/siteinfo/siteinfo_service.go index 1e25cbaa4..c29051b43 100644 --- a/internal/service/siteinfo/siteinfo_service.go +++ b/internal/service/siteinfo/siteinfo_service.go @@ -63,7 +63,6 @@ func NewSiteInfoService( configService *config.ConfigService, questioncommon *questioncommon.QuestionCommon, fileRecordService *file_record.FileRecordService, - ) *SiteInfoService { plugin.RegisterGetSiteURLFunc(func() string { generalSiteInfo, err := siteInfoCommonService.GetSiteGeneral(context.Background()) @@ -409,7 +408,11 @@ func (s *SiteInfoService) SaveSiteAI(ctx context.Context, req *schema.SiteAIReq) Content: string(content), Status: 1, } - return s.siteInfoRepo.SaveByType(ctx, constant.SiteTypeAI, siteInfo) + if err := s.siteInfoRepo.SaveByType(ctx, constant.SiteTypeAI, siteInfo); err != nil { + return err + } + + return nil } func (s *SiteInfoService) maskAIKeys(resp *schema.SiteAIResp) { diff --git a/plugin/plugin.go b/plugin/plugin.go index 8778b1625..3a657fc40 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -130,6 +130,10 @@ func Register(p Base) { if _, ok := p.(Sidebar); ok { registerSidebar(p.(Sidebar)) } + + if _, ok := p.(VectorSearch); ok { + registerVectorSearch(p.(VectorSearch)) + } } type Stack[T Base] struct { diff --git a/plugin/vector_search.go b/plugin/vector_search.go new file mode 100644 index 000000000..134247d6c --- /dev/null +++ b/plugin/vector_search.go @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package plugin + +import ( + "context" + "fmt" + "strings" + + "github.com/sashabaranov/go-openai" + "github.com/segmentfault/pacman/log" +) + +// VectorSearchResult holds a single similarity search result returned by a VectorSearch plugin. +type VectorSearchResult struct { + // ObjectID is the unique identifier of the matched object (question ID or answer ID). + ObjectID string `json:"object_id"` + // ObjectType is "question" or "answer". + ObjectType string `json:"object_type"` + // Metadata is a JSON string containing VectorSearchMetadata for link composition and content retrieval. + Metadata string `json:"metadata"` + // Score is the cosine similarity score (0-1). + Score float64 `json:"score"` +} + +// VectorSearchContent is the document structure passed to plugins for indexing. +type VectorSearchContent struct { + // ObjectID is the unique identifier (question ID or answer ID). + ObjectID string `json:"objectID"` + // ObjectType is "question" or "answer". + ObjectType string `json:"objectType"` + // Title is the question title. + Title string `json:"title"` + // Content is the aggregated text to be embedded (question body + answers + comments). + Content string `json:"content"` + // Metadata is a JSON string containing VectorSearchMetadata. + Metadata string `json:"metadata"` +} + +// VectorSearchDesc describes the vector search engine for display purposes. +type VectorSearchDesc struct { + // Icon is an SVG icon for display. Optional. + Icon string `json:"icon"` + // Link is the URL of the vector search engine. Optional. + Link string `json:"link"` +} + +// VectorSearchMetadata holds IDs for URI composition and content retrieval at query time. +// Shared between plugins and the core MCP controller. +type VectorSearchMetadata struct { + QuestionID string `json:"question_id"` + AnswerID string `json:"answer_id,omitempty"` + Answers []VectorSearchMetadataAnswer `json:"answers,omitempty"` + Comments []VectorSearchMetadataComment `json:"comments,omitempty"` +} + +// VectorSearchMetadataAnswer stores answer ID and its comment IDs in metadata. +type VectorSearchMetadataAnswer struct { + AnswerID string `json:"answer_id"` + Comments []VectorSearchMetadataComment `json:"comments,omitempty"` +} + +// VectorSearchMetadataComment stores a comment ID in metadata. +type VectorSearchMetadataComment struct { + CommentID string `json:"comment_id"` +} + +// VectorSearch is the plugin interface for vector/semantic search engines. +// Plugins implementing this interface manage their own vector storage, embedding computation, +// data synchronization schedule, and similarity search. +type VectorSearch interface { + Base + + // Description returns metadata about the vector search engine. + Description() VectorSearchDesc + + // RegisterSyncer is called by the core to provide a data syncer. + // The plugin should store the syncer and use it to bulk-sync content + // (typically in a background goroutine). + RegisterSyncer(ctx context.Context, syncer VectorSearchSyncer) + + // SearchSimilar performs a semantic similarity search. + // The plugin is responsible for embedding the query text and searching its vector store. + // Returns up to topK results sorted by similarity score (descending). + SearchSimilar(ctx context.Context, query string, topK int) ([]VectorSearchResult, error) + + // UpdateContent upserts a single document in the vector store. + // Called by the core on incremental content changes. + UpdateContent(ctx context.Context, content *VectorSearchContent) error + + // DeleteContent removes a document from the vector store by object ID. + DeleteContent(ctx context.Context, objectID string) error +} + +// VectorSearchSyncer is implemented by the core and provided to plugins via RegisterSyncer. +// Plugins call these methods to pull all content for bulk indexing. +type VectorSearchSyncer interface { + // GetQuestionsPage returns a page of questions with aggregated text (title + body + answers + comments). + GetQuestionsPage(ctx context.Context, page, pageSize int) ([]*VectorSearchContent, error) + // GetAnswersPage returns a page of answers with aggregated text (answer body + parent question title + comments). + GetAnswersPage(ctx context.Context, page, pageSize int) ([]*VectorSearchContent, error) +} + +var ( + // CallVectorSearch is a function that calls all registered VectorSearch plugins. + CallVectorSearch, + registerVectorSearch = MakePlugin[VectorSearch](false) +) + +// GenerateEmbedding is a base utility function that generates an embedding vector +// using an OpenAI-compatible API. Plugins that don't have a built-in vectorizer +// (most vector databases) can call this function with their own credentials. +// Plugins with built-in vectorizers (e.g., Weaviate) can skip this and use their own. +// +// Parameters: +// - ctx: context for cancellation +// - apiHost: the API base URL (e.g. "https://api.openai.com"); "/v1" is appended if missing +// - apiKey: the API key for authentication +// - model: the embedding model name (e.g. "text-embedding-3-small") +// - text: the text to embed +// +// Returns the embedding vector as []float32, or an error. +func GenerateEmbedding(ctx context.Context, apiHost, apiKey, model, text string) ([]float32, error) { + if model == "" { + return nil, fmt.Errorf("embedding model is not configured") + } + if text == "" { + return nil, fmt.Errorf("text is empty") + } + + config := openai.DefaultConfig(apiKey) + config.BaseURL = apiHost + if !strings.HasSuffix(config.BaseURL, "/v1") { + config.BaseURL += "/v1" + } + + log.Debugf("embedding: requesting model=%s baseURL=%s textLen=%d", model, config.BaseURL, len(text)) + + client := openai.NewClientWithConfig(config) + + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{ + Input: []string{text}, + Model: openai.EmbeddingModel(model), + }) + if err != nil { + log.Errorf("embedding: request failed model=%s baseURL=%s err=%v", model, config.BaseURL, err) + return nil, fmt.Errorf("create embeddings failed: %w", err) + } + if len(resp.Data) == 0 { + log.Errorf("embedding: no data returned model=%s baseURL=%s", model, config.BaseURL) + return nil, fmt.Errorf("no embedding returned") + } + + log.Debugf("embedding: success model=%s dimensions=%d usage={prompt=%d,total=%d}", + model, len(resp.Data[0].Embedding), resp.Usage.PromptTokens, resp.Usage.TotalTokens) + return resp.Data[0].Embedding, nil +}