diff --git a/.github/workflows/deploy-image-arm.yml b/.github/workflows/deploy-image-arm.yml index 163b7dc76..9721addb3 100644 --- a/.github/workflows/deploy-image-arm.yml +++ b/.github/workflows/deploy-image-arm.yml @@ -19,6 +19,7 @@ env: jobs: build-and-push-image: + if: github.repository == 'zhayujie/chatgpt-on-wechat' runs-on: ubuntu-latest permissions: contents: read diff --git a/.github/workflows/deploy-image.yml b/.github/workflows/deploy-image.yml index c3c843966..a30b77ffe 100644 --- a/.github/workflows/deploy-image.yml +++ b/.github/workflows/deploy-image.yml @@ -19,6 +19,7 @@ env: jobs: build-and-push-image: + if: github.repository == 'zhayujie/chatgpt-on-wechat' runs-on: ubuntu-latest permissions: contents: read diff --git a/.gitignore b/.gitignore index 9b3bcdf0f..560e6151a 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,5 @@ plugins/banwords/lib/__pycache__ !plugins/hello !plugins/role !plugins/keyword -!plugins/linkai \ No newline at end of file +!plugins/linkai +client_config.json diff --git a/README.md b/README.md index f46c35786..f2ca5db09 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,68 @@ # 简介 -> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。 +> chatgpt-on-wechat(简称CoW)项目是基于大模型的智能对话机器人,支持微信公众号、企业微信应用、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/Gemini/LinkAI/ChatGLM/KIMI/文心一言/讯飞星火/通义千问,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。 最新版本支持的功能如下: -- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信,微信公众号和企业微信应用等部署方式 -- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, 文心一言, 讯飞星火 -- [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型 -- [x] **图片生成:** 支持图片生成 和 图生图(如照片修复),可选择 Dall-E, stable diffusion, replicate, midjourney模型 -- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话等插件 - [X] **Tool工具:** 与操作系统和互联网交互,支持最新信息搜索、数学计算、天气和资讯查询、网页总结,基于 [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 实现 
-- [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、领域知识库、智能客服使用,基于 [LinkAI](https://link-ai.tech/console) 实现 +- ✅ **多端部署:** 有多种部署方式可选择且功能完备,目前已支持微信公众号、企业微信应用、飞书、钉钉等部署方式 +- ✅ **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4o-mini, GPT-4o, GPT-4, Claude-3.5, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM-4,Kimi(月之暗面), MiniMax +- ✅ **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型 +- ✅ **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型 +- ✅ **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件 +- ✅ **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现 -> 欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。 +## 声明 -# 演示 +1. 本项目遵循 [MIT开源协议](/LICENSE),仅用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为 +2. 境内使用该项目时,请使用国内厂商的大模型服务,并进行必要的内容安全审核及过滤 +3. 本项目主要接入协同办公平台,推荐使用公众号、企微自建应用、钉钉、飞书等接入通道,其他通道为历史产物已不维护 +4. 任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任 -https://github.com/zhayujie/chatgpt-on-wechat/assets/26161723/d5154020-36e3-41db-8706-40ce9f3f1b1e +## 演示 -Demo made by [Visionn](https://www.wangpc.cc/) +DEMO视频:https://cdn.link-ai.tech/doc/cow_demo.mp4 -# 交流群 +## 社区 -添加小助手微信进群,请备注 "wechat": +添加小助手微信加入开源项目交流群: - + -# 更新日志 +
+ +# 企业服务 + + + +> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI应用平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署多种模式。 +> +> LinkAI 目前 已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。 + +**企业服务和产品咨询** 可联系产品顾问: + + + +
+ +# 🏷 更新日志 + +>**2024.08.02:** [1.7.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.0) 新增 讯飞4.0 模型、知识库引用来源展示、相关插件优化 + +>**2024.07.19:** [1.6.9版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.9) 新增 gpt-4o-mini 模型、阿里语音识别、企微应用渠道路由优化 + +>**2024.07.05:** [1.6.8版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.8) 和 [1.6.7版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.7),Claude3.5, Gemini 1.5 Pro, MiniMax模型、工作流图片输入、模型列表完善 + +>**2024.06.04:** [1.6.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.6) 和 [1.6.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.5),gpt-4o模型、钉钉流式卡片、讯飞语音识别/合成 + +>**2024.04.26:** [1.6.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.0),新增 Kimi 接入、gpt-4-turbo版本升级、文件总结和语音识别问题修复 + +>**2024.03.26:** [1.5.8版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.8) 和 [1.5.7版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.7),新增 GLM-4、Claude-3 模型,edge-tts 语音支持 + +>**2024.01.26:** [1.5.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.6) 和 [1.5.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.5),钉钉接入,tool插件升级,4-turbo模型更新 + +>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增通义千问模型、Google Gemini + +>**2023.11.10:** [1.5.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.2),新增飞书通道、图像识别对话、黑名单配置 >**2023.11.10:** [1.5.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.0),新增 `gpt-4-turbo`, `dall-e-3`, `tts` 模型接入,完善图像理解&生成、语音识别&生成的多模态能力 @@ -36,23 +72,17 @@ Demo made by [Visionn](https://www.wangpc.cc/) >**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图 ->**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 
平台,可在线创建领域知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。 - ->**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944)) +>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。 ->**2023.04.05:** 支持微信公众号部署,兼容插件,并支持语音图片交互,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686)) +更早更新日志查看: [归档日志](/docs/version/old-version.md) ->**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663)) +
->**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565)) +# 🚀 快速开始 ->**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385)) +快速开始详细文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start) ->**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158) - -# 快速开始 - -## 准备 +## 一、准备 ### 1. 账号注册 @@ -60,7 +90,7 @@ Demo made by [Visionn](https://www.wangpc.cc/) > 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是Dell E模型,每张消耗 $0.016。 -项目同时也支持使用 LinkAI 接口,无需代理,可使用 文心、讯飞、GPT-3、GPT-4 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结和对话等能力。修改配置即可一键切换,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。 +项目同时也支持使用 LinkAI 接口,无需代理,可使用 Kimi、文心、讯飞、GPT-3.5、GPT-4o 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结、工作流等能力。修改配置即可一键使用,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。 ### 2.运行环境 @@ -76,6 +106,8 @@ git clone https://github.com/zhayujie/chatgpt-on-wechat cd chatgpt-on-wechat/ ``` +注: 如遇到网络问题可选择国内镜像 https://gitee.com/zhayujie/chatgpt-on-wechat + **(2) 安装核心依赖 (必选):** > 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。 ```bash @@ -87,25 +119,9 @@ pip3 install -r requirements.txt ```bash pip3 install -r requirements-optional.txt ``` -> 如果某项依赖安装失败请注释掉对应的行再继续。 - -其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,强烈建议安装。 - +> 如果某项依赖安装失败可注释掉对应的行再继续 -使用`google`或`baidu`语音识别需安装`ffmpeg`, - -默认的`openai`语音识别不需要安装`ffmpeg`。 - -参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415) - 
-使用`azure`语音功能需安装依赖,并参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)的环境要求。 -: - -```bash -pip3 install azure-cognitiveservices-speech -``` - -## 配置 +## 二、配置 配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件: @@ -113,13 +129,13 @@ pip3 install azure-cognitiveservices-speech cp config-template.json config.json ``` -然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(请去掉注释): +然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的完整): ```bash # config.json文件内容示例 { - "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY - "model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei + "model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-4, gpt-4-turbo, wenxin, xunfei, glm-4, claude-3-haiku, moonshot + "open_ai_api_key": "YOUR API KEY", # 如果使用openAI模型则填入上面创建的 OpenAI API KEY "proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890" "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 @@ -130,15 +146,13 @@ pip3 install azure-cognitiveservices-speech "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 "speech_recognition": false, # 是否开启语音识别 "group_speech_recognition": false, # 是否开启群组语音识别 - "use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 
当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/ - "azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称 - "azure_api_version": "", # 采用Azure ChatGPT时,API版本 - "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 + "voice_reply_voice": false, # 是否使用语音回复语音 + "character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 # 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。 "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。", "use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ "linkai_api_key": "", # LinkAI Api Key - "linkai_app_code": "" # LinkAI 应用code + "linkai_app_code": "" # LinkAI 应用或工作流code } ``` **配置说明:** @@ -159,11 +173,11 @@ pip3 install azure-cognitiveservices-speech + 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图); + 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图); -+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。 ++ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊) **4.其他配置** -+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` , `claude` , `xunfei`(其中gpt-4 api暂未完全开放,申请通过后可使用) ++ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `gpt-4o-mini`, `gpt-4o`, `gpt-4`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件 + `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat) + `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351) + 
对于图像生成,除满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix` @@ -171,7 +185,7 @@ pip3 install azure-cognitiveservices-speech + `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话) + `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。 + `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。 -+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。 ++ `hot_reload`: 程序退出后,暂存登录状态,默认关闭。 + `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43)) + `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。 @@ -179,11 +193,11 @@ pip3 install azure-cognitiveservices-speech + `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat) + `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建 -+ `linkai_app_code`: LinkAI 应用code,选填 ++ `linkai_app_code`: LinkAI 应用或工作流的code,选填 **本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。** -## 运行 +## 三、运行 ### 1.本地运行 python3 app.py # windows环境下该命令通常为 python app.py ``` -终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。 +终端输出二维码后,进行扫码登录,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的账号需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。 ### 2.服务器部署 使用nohup命令在后台运行程序: ```bash nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码 ``` 扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 
掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。 @@ -216,7 +229,7 @@ nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通 > 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。 -#### (1) 下载 docker-compose.yml 文件 +**(1) 下载 docker-compose.yml 文件** ```bash wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml @@ -224,7 +237,7 @@ wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml 下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。 -#### (2) 启动容器 +**(2) 启动容器** 在 `docker-compose.yml` 所在目录下执行以下命令启动容器: @@ -245,7 +258,7 @@ sudo docker compose up -d sudo docker logs -f chatgpt-on-wechat ``` -#### (3) 插件使用 +**(3) 插件使用** 如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template) 重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射: @@ -267,12 +280,22 @@ volumes: [![Deploy on Railway](https://railway.app/button.svg)](https://railway.app/template/qApznZ?referralCode=RC3znh) -## 常见问题 +
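补充一点:前文「二、配置」提到,`config.json` 必须去掉注释、保持合法的 JSON 格式,直接照抄带 `#` 注释的示例会导致解析失败。下面是一个清理注释的小脚本示意(仅作演示,并非项目自带工具,其中的函数名为自拟;项目实际如何加载配置请以仓库中的 `config.py` 为准):

```python
import json

def strip_json_comments(text: str) -> str:
    """删除引号外以 # 开始直到行尾的注释;字符串内部的 #(如 "#help")会被保留。"""
    out = []
    in_str = False      # 当前是否位于字符串字面量内
    escaped = False     # 上一个字符是否为转义符 \
    in_comment = False  # 当前是否位于注释内
    for ch in text:
        if in_comment:
            # 注释一直延伸到行尾,换行符本身保留
            if ch == "\n":
                in_comment = False
                out.append(ch)
            continue
        if in_str:
            out.append(ch)
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_str = False
            continue
        if ch == '"':
            in_str = True
            out.append(ch)
        elif ch == "#":
            in_comment = True
        else:
            out.append(ch)
    return "".join(out)

# 带注释的配置片段(模仿上文 README 中的示例写法)
raw = '''{
    "model": "gpt-3.5-turbo",  # 模型名称
    "single_chat_prefix": ["bot", "@bot"],  # 私聊触发前缀
    "subscribe_msg": "输入#help 查看指令"
}'''
config = json.loads(strip_json_comments(raw))
print(config["model"])  # 输出: gpt-3.5-turbo
```

清理后的文本即可被 `json.loads` 正常解析;注意字符串值中的 `#`(例如 `subscribe_msg` 里的 `#help`)不会被误删。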
+ +# 🔎 常见问题 FAQs: -或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考) +或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (语料持续完善中,回复仅供参考) + +# 🛠️ 开发 + +欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) 实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。 + +# ✉ 联系 + +欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品顾问](https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg)咨询。 -## 联系 +# 🌟 贡献者 -欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。参与更多讨论可加入技术交流群。 +![cow contributors](https://contrib.rocks/image?repo=zhayujie/chatgpt-on-wechat&max=1000) diff --git a/app.py b/app.py index 1bd6dad56..ff2a6c774 100644 --- a/app.py +++ b/app.py @@ -3,11 +3,13 @@ import os import signal import sys +import time from channel import channel_factory -from common.log import logger -from config import conf, load_config +from common import const +from config import load_config from plugins import * +import threading def sigterm_handler_wrap(_signo): @@ -23,6 +25,21 @@ def func(_signo, _stack_frame): signal.signal(_signo, func) +def start_channel(channel_name: str): + channel = channel_factory.create_channel(channel_name) + if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", + const.FEISHU, const.DINGTALK]: + PluginManager().load_plugins() + + if conf().get("use_linkai"): + try: + from common import linkai_client + threading.Thread(target=linkai_client.start, args=(channel,)).start() + except Exception as e: + pass + channel.startup() + + def run(): try: # load config @@ -40,14 
+57,11 @@ def run(): if channel_name == "wxy": os.environ["WECHATY_LOG"] = "warn" - # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' - channel = channel_factory.create_channel(channel_name) - if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework"]: - PluginManager().load_plugins() + start_channel(channel_name) - # startup channel - channel.startup() + while True: + time.sleep(1) except Exception as e: logger.error("App startup failed!") logger.exception(e) diff --git a/bot/ali/ali_qwen_bot.py b/bot/ali/ali_qwen_bot.py new file mode 100644 index 000000000..ae9d7674b --- /dev/null +++ b/bot/ali/ali_qwen_bot.py @@ -0,0 +1,214 @@ +# encoding:utf-8 + +import json +import time +from typing import List, Tuple + +import openai +import openai.error +import broadscope_bailian +from broadscope_bailian import ChatQaMessage + +from bot.bot import Bot +from bot.ali.ali_qwen_session import AliQwenSession +from bot.session_manager import SessionManager +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger +from common import const +from config import conf, load_config + +class AliQwenBot(Bot): + def __init__(self): + super().__init__() + self.api_key_expired_time = self.set_api_key() + self.sessions = SessionManager(AliQwenSession, model=conf().get("model", const.QWEN)) + + def api_key_client(self): + return broadscope_bailian.AccessTokenClient(access_key_id=self.access_key_id(), access_key_secret=self.access_key_secret()) + + def access_key_id(self): + return conf().get("qwen_access_key_id") + + def access_key_secret(self): + return conf().get("qwen_access_key_secret") + + def agent_key(self): + return conf().get("qwen_agent_key") + + def app_id(self): + return conf().get("qwen_app_id") + + def node_id(self): + return conf().get("qwen_node_id", "") + + def temperature(self): + return conf().get("temperature", 0.2 ) + + def top_p(self): + return 
conf().get("top_p", 1) + + def reply(self, query, context=None): + # acquire reply content + if context.type == ContextType.TEXT: + logger.info("[QWEN] query={}".format(query)) + + session_id = context["session_id"] + reply = None + clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) + if query in clear_memory_commands: + self.sessions.clear_session(session_id) + reply = Reply(ReplyType.INFO, "记忆已清除") + elif query == "#清除所有": + self.sessions.clear_all_session() + reply = Reply(ReplyType.INFO, "所有人记忆已清除") + elif query == "#更新配置": + load_config() + reply = Reply(ReplyType.INFO, "配置已更新") + if reply: + return reply + session = self.sessions.session_query(query, session_id) + logger.debug("[QWEN] session query={}".format(session.messages)) + + reply_content = self.reply_text(session) + logger.debug( + "[QWEN] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( + session.messages, + session_id, + reply_content["content"], + reply_content["completion_tokens"], + ) + ) + if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0: + reply = Reply(ReplyType.ERROR, reply_content["content"]) + elif reply_content["completion_tokens"] > 0: + self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) + reply = Reply(ReplyType.TEXT, reply_content["content"]) + else: + reply = Reply(ReplyType.ERROR, reply_content["content"]) + logger.debug("[QWEN] reply {} used 0 tokens.".format(reply_content)) + return reply + + else: + reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) + return reply + + def reply_text(self, session: AliQwenSession, retry_count=0) -> dict: + """ + call bailian's ChatCompletion to get the answer + :param session: a conversation session + :param retry_count: retry count + :return: {} + """ + try: + prompt, history = self.convert_messages_format(session.messages) + self.update_api_key_if_expired() + # NOTE 
阿里百炼的call()函数未提供temperature参数,考虑到temperature和top_p参数作用相同,取两者较小的值作为top_p参数传入,详情见文档 https://help.aliyun.com/document_detail/2587502.htm + response = broadscope_bailian.Completions().call(app_id=self.app_id(), prompt=prompt, history=history, top_p=min(self.temperature(), self.top_p())) + completion_content = self.get_completion_content(response, self.node_id()) + completion_tokens, total_tokens = self.calc_tokens(session.messages, completion_content) + return { + "total_tokens": total_tokens, + "completion_tokens": completion_tokens, + "content": completion_content, + } + except Exception as e: + need_retry = retry_count < 2 + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + if isinstance(e, openai.error.RateLimitError): + logger.warn("[QWEN] RateLimitError: {}".format(e)) + result["content"] = "提问太快啦,请休息一下再问我吧" + if need_retry: + time.sleep(20) + elif isinstance(e, openai.error.Timeout): + logger.warn("[QWEN] Timeout: {}".format(e)) + result["content"] = "我没有收到你的消息" + if need_retry: + time.sleep(5) + elif isinstance(e, openai.error.APIError): + logger.warn("[QWEN] Bad Gateway: {}".format(e)) + result["content"] = "请再问我一次" + if need_retry: + time.sleep(10) + elif isinstance(e, openai.error.APIConnectionError): + logger.warn("[QWEN] APIConnectionError: {}".format(e)) + need_retry = False + result["content"] = "我连接不到你的网络" + else: + logger.exception("[QWEN] Exception: {}".format(e)) + need_retry = False + self.sessions.clear_session(session.session_id) + + if need_retry: + logger.warn("[QWEN] 第{}次重试".format(retry_count + 1)) + return self.reply_text(session, retry_count + 1) + else: + return result + + def set_api_key(self): + api_key, expired_time = self.api_key_client().create_token(agent_key=self.agent_key()) + broadscope_bailian.api_key = api_key + return expired_time + + def update_api_key_if_expired(self): + if time.time() > self.api_key_expired_time: + self.api_key_expired_time = self.set_api_key() + + def convert_messages_format(self, messages) -> 
Tuple[str, List[ChatQaMessage]]: + history = [] + user_content = '' + assistant_content = '' + system_content = '' + for message in messages: + role = message.get('role') + if role == 'user': + user_content += message.get('content') + elif role == 'assistant': + assistant_content = message.get('content') + history.append(ChatQaMessage(user_content, assistant_content)) + user_content = '' + assistant_content = '' + elif role =='system': + system_content += message.get('content') + if user_content == '': + raise Exception('no user message') + if system_content != '': + # NOTE 模拟系统消息,测试发现人格描述以"你需要扮演ChatGPT"开头能够起作用,而以"你是ChatGPT"开头模型会直接否认 + system_qa = ChatQaMessage(system_content, '好的,我会严格按照你的设定回答问题') + history.insert(0, system_qa) + logger.debug("[QWEN] converted qa messages: {}".format([item.to_dict() for item in history])) + logger.debug("[QWEN] user content as prompt: {}".format(user_content)) + return user_content, history + + def get_completion_content(self, response, node_id): + if not response['Success']: + return f"[ERROR]\n{response['Code']}:{response['Message']}" + text = response['Data']['Text'] + if node_id == '': + return text + # TODO: 当使用流程编排创建大模型应用时,响应结构如下,最终结果在['finalResult'][node_id]['response']['text']中,暂时先这么写 + # { + # 'Success': True, + # 'Code': None, + # 'Message': None, + # 'Data': { + # 'ResponseId': '9822f38dbacf4c9b8daf5ca03a2daf15', + # 'SessionId': 'session_id', + # 'Text': '{"finalResult":{"LLM_T7islK":{"params":{"modelId":"qwen-plus-v1","prompt":"${systemVars.query}${bizVars.Text}"},"response":{"text":"作为一个AI语言模型,我没有年龄,因为我没有生日。\n我只是一个程序,没有生命和身体。"}}}}', + # 'Thoughts': [], + # 'Debug': {}, + # 'DocReferences': [] + # }, + # 'RequestId': '8e11d31551ce4c3f83f49e6e0dd998b0', + # 'Failed': None + # } + text_dict = json.loads(text) + completion_content = text_dict['finalResult'][node_id]['response']['text'] + return completion_content + + def calc_tokens(self, messages, completion_content): + completion_tokens = len(completion_content) + 
prompt_tokens = 0 + for message in messages: + prompt_tokens += len(message["content"]) + return completion_tokens, prompt_tokens + completion_tokens diff --git a/bot/ali/ali_qwen_session.py b/bot/ali/ali_qwen_session.py new file mode 100644 index 000000000..0eb1c4a1e --- /dev/null +++ b/bot/ali/ali_qwen_session.py @@ -0,0 +1,62 @@ +from bot.session_manager import Session +from common.log import logger + +""" + e.g. + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, + {"role": "user", "content": "Where was it played?"} + ] +""" + +class AliQwenSession(Session): + def __init__(self, session_id, system_prompt=None, model="qianwen"): + super().__init__(session_id, system_prompt) + self.model = model + self.reset() + + def discard_exceeding(self, max_tokens, cur_tokens=None): + precise = True + try: + cur_tokens = self.calc_tokens() + except Exception as e: + precise = False + if cur_tokens is None: + raise e + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + while cur_tokens > max_tokens: + if len(self.messages) > 2: + self.messages.pop(1) + elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": + self.messages.pop(1) + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = cur_tokens - max_tokens + break + elif len(self.messages) == 2 and self.messages[1]["role"] == "user": + logger.warn("user message exceed max_tokens. 
total_tokens={}".format(cur_tokens)) + break + else: + logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) + break + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = cur_tokens - max_tokens + return cur_tokens + + def calc_tokens(self): + return num_tokens_from_messages(self.messages, self.model) + +def num_tokens_from_messages(messages, model): + """Returns the number of tokens used by a list of messages.""" + # 官方token计算规则:"对于中文文本来说,1个token通常对应一个汉字;对于英文文本来说,1个token通常对应3至4个字母或1个单词" + # 详情请产看文档:https://help.aliyun.com/document_detail/2586397.html + # 目前根据字符串长度粗略估计token数,不影响正常使用 + tokens = 0 + for msg in messages: + tokens += len(msg["content"]) + return tokens diff --git a/bot/baidu/baidu_wenxin.py b/bot/baidu/baidu_wenxin.py index f35e0fa38..dd660afa1 100644 --- a/bot/baidu/baidu_wenxin.py +++ b/bot/baidu/baidu_wenxin.py @@ -1,6 +1,8 @@ # encoding:utf-8 -import requests, json +import requests +import json +from common import const from bot.bot import Bot from bot.session_manager import SessionManager from bridge.context import ContextType @@ -16,9 +18,20 @@ class BaiduWenxinBot(Bot): def __init__(self): super().__init__() - wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant" - if conf().get("model") and conf().get("model") == "wenxin-4": - wenxin_model = "completions_pro" + wenxin_model = conf().get("baidu_wenxin_model") + self.prompt_enabled = conf().get("baidu_wenxin_prompt_enabled") + if self.prompt_enabled: + self.prompt = conf().get("character_desc", "") + if self.prompt == "": + logger.warn("[BAIDU] Although you enabled model prompt, character_desc is not specified.") + if wenxin_model is not None: + wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant" + else: + if conf().get("model") and conf().get("model") == const.WEN_XIN: + wenxin_model = "completions" + elif conf().get("model") and conf().get("model") == const.WEN_XIN_4: + wenxin_model = 
"completions_pro" + self.sessions = SessionManager(BaiduWenxinSession, model=wenxin_model) def reply(self, query, context=None): @@ -76,7 +89,7 @@ def reply_text(self, session: BaiduWenxinSession, retry_count=0): headers = { 'Content-Type': 'application/json' } - payload = {'messages': session.messages} + payload = {'messages': session.messages, 'system': self.prompt} if self.prompt_enabled else {'messages': session.messages} response = requests.request("POST", url, headers=headers, data=json.dumps(payload)) response_text = json.loads(response.text) logger.info(f"[BAIDU] response text={response_text}") @@ -94,7 +107,7 @@ def reply_text(self, session: BaiduWenxinSession, retry_count=0): logger.warn("[BAIDU] Exception: {}".format(e)) need_retry = False self.sessions.clear_session(session.session_id) - result = {"completion_tokens": 0, "content": "出错了: {}".format(e)} + result = {"total_tokens": 0, "completion_tokens": 0, "content": "出错了: {}".format(e)} return result def get_access_token(self): diff --git a/bot/bot_factory.py b/bot/bot_factory.py index da12f952d..a6ef2415b 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -43,4 +43,30 @@ def create_bot(bot_type): elif bot_type == const.CLAUDEAI: from bot.claude.claude_ai_bot import ClaudeAIBot return ClaudeAIBot() + elif bot_type == const.CLAUDEAPI: + from bot.claudeapi.claude_api_bot import ClaudeAPIBot + return ClaudeAPIBot() + elif bot_type == const.QWEN: + from bot.ali.ali_qwen_bot import AliQwenBot + return AliQwenBot() + elif bot_type == const.QWEN_DASHSCOPE: + from bot.dashscope.dashscope_bot import DashscopeBot + return DashscopeBot() + elif bot_type == const.GEMINI: + from bot.gemini.google_gemini_bot import GoogleGeminiBot + return GoogleGeminiBot() + + elif bot_type == const.ZHIPU_AI: + from bot.zhipuai.zhipuai_bot import ZHIPUAIBot + return ZHIPUAIBot() + + elif bot_type == const.MOONSHOT: + from bot.moonshot.moonshot_bot import MoonshotBot + return MoonshotBot() + + elif bot_type == 
const.MiniMax: + from bot.minimax.minimax_bot import MinimaxBot + return MinimaxBot() + + raise RuntimeError diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 8c9a2504a..241b3a7be 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -148,8 +148,9 @@ def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_cou time.sleep(10) elif isinstance(e, openai.error.APIConnectionError): logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) - need_retry = False result["content"] = "我连接不到你的网络" + if need_retry: + time.sleep(5) else: logger.exception("[CHATGPT] Exception: {}".format(e)) need_retry = False @@ -170,24 +171,70 @@ def __init__(self): self.args["deployment_id"] = conf().get("azure_deployment_id") def create_img(self, query, retry_count=0, api_key=None): - api_version = "2022-08-03-preview" - url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version) - api_key = api_key or openai.api_key - headers = {"api-key": api_key, "Content-Type": "application/json"} - try: - body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")} - submission = requests.post(url, headers=headers, json=body) - operation_location = submission.headers["Operation-Location"] - retry_after = submission.headers["Retry-after"] - status = "" - image_url = "" - while status != "Succeeded": - logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds") - time.sleep(int(retry_after)) - response = requests.get(operation_location, headers=headers) - status = response.json()["status"] - image_url = response.json()["result"]["contentUrl"] - return True, image_url - except Exception as e: - logger.error("create image error: {}".format(e)) - return False, "图片生成失败" + text_to_image_model = conf().get("text_to_image") + if text_to_image_model == "dall-e-2": + api_version = "2023-06-01-preview" + endpoint = 
conf().get("azure_openai_dalle_api_base") or conf().get("open_ai_api_base")
+            # 检查endpoint是否以/结尾
+            if not endpoint.endswith("/"):
+                endpoint = endpoint + "/"
+            url = "{}openai/images/generations:submit?api-version={}".format(endpoint, api_version)
+            api_key = conf().get("azure_openai_dalle_api_key") or conf().get("open_ai_api_key")
+            headers = {"api-key": api_key, "Content-Type": "application/json"}
+            try:
+                body = {"prompt": query, "size": conf().get("image_create_size", "256x256"), "n": 1}
+                submission = requests.post(url, headers=headers, json=body)
+                operation_location = submission.headers['operation-location']
+                status = ""
+                while status != "succeeded":
+                    if retry_count > 3:
+                        return False, "图片生成失败"
+                    time.sleep(3)  # 等待异步生成完成后再轮询
+                    response = requests.get(operation_location, headers=headers)
+                    status = response.json()['status']
+                    retry_count += 1
+                image_url = response.json()['result']['data'][0]['url']
+                return True, image_url
+            except Exception as e:
+                logger.error("create image error: {}".format(e))
+                return False, "图片生成失败"
+        elif text_to_image_model == "dall-e-3":
+            api_version = conf().get("azure_api_version", "2024-02-15-preview")
+            endpoint = conf().get("azure_openai_dalle_api_base") or conf().get("open_ai_api_base")
+            # 检查endpoint是否以/结尾
+            if not endpoint.endswith("/"):
+                endpoint = endpoint + "/"
+            url = "{}openai/deployments/{}/images/generations?api-version={}".format(endpoint, conf().get("azure_openai_dalle_deployment_id") or conf().get("text_to_image"), api_version)
+            api_key = conf().get("azure_openai_dalle_api_key") or conf().get("open_ai_api_key")
+            headers = {"api-key": api_key, "Content-Type": "application/json"}
+            try:
+                body = {"prompt": query, "size": conf().get("image_create_size", "1024x1024"), "quality": conf().get("dalle3_image_quality", "standard")}
+                response = requests.post(url, headers=headers, json=body)
+                response.raise_for_status()  # 检查请求是否成功
+                data = response.json()
+
+                # 检查响应中是否包含图像 URL
+                if 'data' in data and len(data['data']) > 0 and 'url' in data['data'][0]:
+                    image_url = data['data'][0]['url']
+                    return True, image_url
+                else:
+ error_message = "响应中没有图像 URL" + logger.error(error_message) + return False, "图片生成失败" + + except requests.exceptions.RequestException as e: + # 捕获所有请求相关的异常 + try: + error_detail = response.json().get('error', {}).get('message', str(e)) + except ValueError: + error_detail = str(e) + error_message = f"{error_detail}" + logger.error(error_message) + return False, error_message + + except Exception as e: + # 捕获所有其他异常 + error_message = f"生成图像时发生错误: {e}" + logger.error(error_message) + return False, "图片生成失败" + else: + return False, "图片生成失败,未配置text_to_image参数" diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py index e7dabecfd..aa34ba316 100644 --- a/bot/chatgpt/chat_gpt_session.py +++ b/bot/chatgpt/chat_gpt_session.py @@ -57,17 +57,20 @@ def calc_tokens(self): def num_tokens_from_messages(messages, model): """Returns the number of tokens used by a list of messages.""" - if model in ["wenxin", "xunfei"]: + if model in ["wenxin", "xunfei", const.GEMINI]: return num_tokens_by_character(messages) import tiktoken - if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]: + if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot", const.LINKAI_35]: return num_tokens_from_messages(messages, model="gpt-3.5-turbo") elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]: + "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview", + "gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW, const.GPT4_TURBO_01_25, + const.GPT_4o, const.GPT_4O_0806, const.GPT_4o_MINI, const.LINKAI_4o, const.LINKAI_4_TURBO]: return num_tokens_from_messages(messages, model="gpt-4") - + elif model.startswith("claude-3"): + return num_tokens_from_messages(messages, model="gpt-3.5-turbo") try: encoding = 
tiktoken.encoding_for_model(model)
    except KeyError:
diff --git a/bot/claudeapi/claude_api_bot.py b/bot/claudeapi/claude_api_bot.py
new file mode 100644
index 000000000..e5452e5c0
--- /dev/null
+++ b/bot/claudeapi/claude_api_bot.py
@@ -0,0 +1,135 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+import anthropic
+
+from bot.bot import Bot
+from bot.openai.open_ai_image import OpenAIImage
+from bot.chatgpt.chat_gpt_session import ChatGPTSession
+from bot.gemini.google_gemini_bot import GoogleGeminiBot
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+
+user_session = dict()
+
+
+# Claude官方API对话模型
+class ClaudeAPIBot(Bot, OpenAIImage):
+    def __init__(self):
+        super().__init__()
+        self.claudeClient = anthropic.Anthropic(
+            api_key=conf().get("claude_api_key")
+        )
+        openai.api_key = conf().get("open_ai_api_key")
+        if conf().get("open_ai_api_base"):
+            openai.api_base = conf().get("open_ai_api_base")
+        proxy = conf().get("proxy")
+        if proxy:
+            openai.proxy = proxy
+
+        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "text-davinci-003")
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context and context.type:
+            if context.type == ContextType.TEXT:
+                logger.info("[CLAUDE_API] query={}".format(query))
+                session_id = context["session_id"]
+                reply = None
+                if query == "#清除记忆":
+                    self.sessions.clear_session(session_id)
+                    reply = Reply(ReplyType.INFO, "记忆已清除")
+                elif query == "#清除所有":
+                    self.sessions.clear_all_session()
+                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+                else:
+                    session = self.sessions.session_query(query, session_id)
+                    result = self.reply_text(session)
+                    logger.info(result)
+                    total_tokens, completion_tokens, reply_content = (
+                        result["total_tokens"],
+                        result["completion_tokens"],
+                        result["content"],
+                    )
+                    logger.debug(
+                        "[CLAUDE_API] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
+                    )
+
+                    if total_tokens == 0:
+                        reply = Reply(ReplyType.ERROR, reply_content)
+                    else:
+                        self.sessions.session_reply(reply_content, session_id, total_tokens)
+                        reply = Reply(ReplyType.TEXT, reply_content)
+                return reply
+            elif context.type == ContextType.IMAGE_CREATE:
+                ok, retstring = self.create_img(query, 0)
+                reply = None
+                if ok:
+                    reply = Reply(ReplyType.IMAGE_URL, retstring)
+                else:
+                    reply = Reply(ReplyType.ERROR, retstring)
+                return reply
+
+    def reply_text(self, session: ChatGPTSession, retry_count=0):
+        try:
+            actual_model = self._model_mapping(conf().get("model"))
+            response = self.claudeClient.messages.create(
+                model=actual_model,
+                max_tokens=1024,
+                # system=conf().get("system"),
+                messages=GoogleGeminiBot.filter_messages(session.messages)
+            )
+            # response = openai.Completion.create(prompt=str(session), **self.args)
+            res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
+            total_tokens = response.usage.input_tokens + response.usage.output_tokens
+            completion_tokens = response.usage.output_tokens
+            logger.info("[CLAUDE_API] reply={}".format(res_content))
+            return {
+                "total_tokens": total_tokens,
+                "completion_tokens": completion_tokens,
+                "content": res_content,
+            }
+        except Exception as e:
+            need_retry = retry_count < 2
+            result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if isinstance(e, openai.error.RateLimitError):
+                logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
+                result["content"] = "提问太快啦,请休息一下再问我吧"
+                if need_retry:
+                    time.sleep(20)
+            elif isinstance(e, openai.error.Timeout):
+                logger.warn("[CLAUDE_API] Timeout: {}".format(e))
+                result["content"] = "我没有收到你的消息"
+                if need_retry:
+                    time.sleep(5)
+            elif isinstance(e, openai.error.APIConnectionError):
+                logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
+                need_retry = False
+                result["content"] = "我连接不到你的网络"
+            else:
+                logger.warn("[CLAUDE_API] Exception: {}".format(e))
+                need_retry = False
+                self.sessions.clear_session(session.session_id)
+
+            if need_retry:
+                logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
+                return self.reply_text(session, retry_count + 1)
+            else:
+                return result
+
+    def _model_mapping(self, model) -> str:
+        if model == "claude-3-opus":
+            return "claude-3-opus-20240229"
+        elif model == "claude-3-sonnet":
+            return "claude-3-sonnet-20240229"
+        elif model == "claude-3-haiku":
+            return "claude-3-haiku-20240307"
+        elif model == "claude-3.5-sonnet":
+            return "claude-3-5-sonnet-20240620"
+        return model
diff --git a/bot/dashscope/dashscope_bot.py b/bot/dashscope/dashscope_bot.py
new file mode 100644
index 000000000..07554c4d2
--- /dev/null
+++ b/bot/dashscope/dashscope_bot.py
@@ -0,0 +1,117 @@
+# encoding:utf-8
+
+from bot.bot import Bot
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf, load_config
+from .dashscope_session import DashscopeSession
+import os
+import dashscope
+from http import HTTPStatus
+
+
+dashscope_models = {
+    "qwen-turbo": dashscope.Generation.Models.qwen_turbo,
+    "qwen-plus": dashscope.Generation.Models.qwen_plus,
+    "qwen-max": dashscope.Generation.Models.qwen_max,
+    "qwen-bailian-v1": dashscope.Generation.Models.bailian_v1
+}
+
+
+# Dashscope(通义千问)对话模型API
+class DashscopeBot(Bot):
+    def __init__(self):
+        super().__init__()
+        self.sessions = SessionManager(DashscopeSession, model=conf().get("model") or "qwen-plus")
+        self.model_name = conf().get("model") or "qwen-plus"
+        self.api_key = conf().get("dashscope_api_key")
+        os.environ["DASHSCOPE_API_KEY"] = self.api_key
+        self.client = dashscope.Generation
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context.type == ContextType.TEXT:
+            logger.info("[DASHSCOPE] query={}".format(query))
+
+            session_id = context["session_id"]
+            reply = None
+            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
+            if query in clear_memory_commands:
+                self.sessions.clear_session(session_id)
+                reply = Reply(ReplyType.INFO, "记忆已清除")
+            elif query == "#清除所有":
+                self.sessions.clear_all_session()
+                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+            elif query == "#更新配置":
+                load_config()
+                reply = Reply(ReplyType.INFO, "配置已更新")
+            if reply:
+                return reply
+            session = self.sessions.session_query(query, session_id)
+            logger.debug("[DASHSCOPE] session query={}".format(session.messages))
+
+            reply_content = self.reply_text(session)
+            logger.debug(
+                "[DASHSCOPE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                    session.messages,
+                    session_id,
+                    reply_content["content"],
+                    reply_content["completion_tokens"],
+                )
+            )
+            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+            elif reply_content["completion_tokens"] > 0:
+                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                reply = Reply(ReplyType.TEXT, reply_content["content"])
+            else:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+                logger.debug("[DASHSCOPE] reply {} used 0 tokens.".format(reply_content))
+            return reply
+        else:
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+            return reply
+
+    def reply_text(self, session: DashscopeSession, retry_count=0) -> dict:
+        """
+        call dashscope's Generation API to get the answer
+        :param session: a conversation session
+        :param retry_count: retry count
+        :return: {}
+        """
+        try:
+            dashscope.api_key = self.api_key
+            response = self.client.call(
+                dashscope_models[self.model_name],
+                messages=session.messages,
+                result_format="message"
+            )
+            if response.status_code == HTTPStatus.OK:
+                content = response.output.choices[0]["message"]["content"]
+                return {
+                    "total_tokens": response.usage["total_tokens"],
+                    "completion_tokens": response.usage["output_tokens"],
+                    "content": content,
+                }
+            else:
+                logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
+                    response.request_id, response.status_code,
+                    response.code, response.message
+                ))
+                need_retry = retry_count < 2
+                result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+                if need_retry:
+                    return self.reply_text(session, retry_count + 1)
+                else:
+                    return result
+        except Exception as e:
+            logger.exception(e)
+            need_retry = retry_count < 2
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if need_retry:
+                return self.reply_text(session, retry_count + 1)
+            else:
+                return result
diff --git a/bot/dashscope/dashscope_session.py b/bot/dashscope/dashscope_session.py
new file mode 100644
index 000000000..0de57b926
--- /dev/null
+++ b/bot/dashscope/dashscope_session.py
@@ -0,0 +1,51 @@
+from bot.session_manager import Session
+from common.log import logger
+
+
+class DashscopeSession(Session):
+    def __init__(self, session_id, system_prompt=None, model="qwen-turbo"):
+        super().__init__(session_id)
+        self.reset()
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        precise = True
+        try:
+            cur_tokens = self.calc_tokens()
+        except Exception as e:
+            precise = False
+            if cur_tokens is None:
+                raise e
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+        while cur_tokens > max_tokens:
+            if len(self.messages) > 2:
+                self.messages.pop(1)
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+                self.messages.pop(1)
+                if precise:
+                    cur_tokens = self.calc_tokens()
+                else:
+                    cur_tokens = cur_tokens - max_tokens
+                break
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
+                                                                                       len(self.messages)))
+                break
+        if precise:
+            cur_tokens = self.calc_tokens()
+        else:
+            cur_tokens = cur_tokens - max_tokens
+        return cur_tokens
+
+    def calc_tokens(self):
+        return num_tokens_from_messages(self.messages)
+
+
+def num_tokens_from_messages(messages):
+    # 只是大概,具体计算规则:https://help.aliyun.com/zh/dashscope/developer-reference/token-api?spm=a2c4g.11186623.0.0.4d8b12b0BkP3K9
+    tokens = 0
+    for msg in messages:
+        tokens += len(msg["content"])
+    return tokens
diff --git a/bot/gemini/google_gemini_bot.py b/bot/gemini/google_gemini_bot.py
new file mode 100644
index 000000000..8a4100ae2
--- /dev/null
+++ b/bot/gemini/google_gemini_bot.py
@@ -0,0 +1,81 @@
+"""
+Google gemini bot
+
+@author zhayujie
+@Date 2023/12/15
+"""
+# encoding:utf-8
+
+from bot.bot import Bot
+import google.generativeai as genai
+from bot.session_manager import SessionManager
+from bridge.context import ContextType, Context
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
+
+
+# Google Gemini对话模型API
+class GoogleGeminiBot(Bot):
+
+    def __init__(self):
+        super().__init__()
+        self.api_key = conf().get("gemini_api_key")
+        # 复用文心的token计算方式
+        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
+        self.model = conf().get("model") or "gemini-pro"
+        if self.model == "gemini":
+            self.model = "gemini-pro"
+
+    def reply(self, query, context: Context = None) -> Reply:
+        try:
+            if context.type != ContextType.TEXT:
+                logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
+                return Reply(ReplyType.TEXT, None)
+            logger.info(f"[Gemini] query={query}")
+            session_id = context["session_id"]
+            session = self.sessions.session_query(query, session_id)
+            gemini_messages = 
self._convert_to_gemini_messages(self.filter_messages(session.messages)) + genai.configure(api_key=self.api_key) + model = genai.GenerativeModel(self.model) + response = model.generate_content(gemini_messages) + reply_text = response.text + self.sessions.session_reply(reply_text, session_id) + logger.info(f"[Gemini] reply={reply_text}") + return Reply(ReplyType.TEXT, reply_text) + except Exception as e: + logger.error("[Gemini] fetch reply error, may contain unsafe content") + logger.error(e) + return Reply(ReplyType.ERROR, "invoke [Gemini] api failed!") + + def _convert_to_gemini_messages(self, messages: list): + res = [] + for msg in messages: + if msg.get("role") == "user": + role = "user" + elif msg.get("role") == "assistant": + role = "model" + else: + continue + res.append({ + "role": role, + "parts": [{"text": msg.get("content")}] + }) + return res + + @staticmethod + def filter_messages(messages: list): + res = [] + turn = "user" + if not messages: + return res + for i in range(len(messages) - 1, -1, -1): + message = messages[i] + if message.get("role") != turn: + continue + res.insert(0, message) + if turn == "user": + turn = "assistant" + elif turn == "assistant": + turn = "user" + return res diff --git a/bot/linkai/link_ai_bot.py b/bot/linkai/link_ai_bot.py index 3788c6bd4..95c514dae 100644 --- a/bot/linkai/link_ai_bot.py +++ b/bot/linkai/link_ai_bot.py @@ -1,19 +1,21 @@ # access LinkAI knowledge base platform # docs: https://link-ai.tech/platform/link-app/wechat +import re import time - import requests - +import config from bot.bot import Bot from bot.chatgpt.chat_gpt_session import ChatGPTSession -from bot.openai.open_ai_image import OpenAIImage from bot.session_manager import SessionManager from bridge.context import Context, ContextType from bridge.reply import Reply, ReplyType from common.log import logger from config import conf, pconf - +import threading +from common import memory, utils +import base64 +import os class LinkAIBot(Bot): # 
authentication failed @@ -22,13 +24,16 @@ class LinkAIBot(Bot): def __init__(self): super().__init__() - self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") + self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo") self.args = {} def reply(self, query, context: Context = None) -> Reply: if context.type == ContextType.TEXT: return self._chat(query, context) elif context.type == ContextType.IMAGE_CREATE: + if not conf().get("text_to_image"): + logger.warn("[LinkAI] text_to_image is not enabled, ignore the IMAGE_CREATE request") + return Reply(ReplyType.TEXT, "") ok, res = self.create_img(query, 0) if ok: reply = Reply(ReplyType.IMAGE_URL, res) @@ -47,10 +52,10 @@ def _chat(self, query, context, retry_count=0) -> Reply: :param retry_count: 当前递归重试次数 :return: 回复 """ - if retry_count >= 2: + if retry_count > 2: # exit from retry 2 times logger.warn("[LINKAI] failed after maximum number of retry times") - return Reply(ReplyType.ERROR, "请再问我一次吧") + return Reply(ReplyType.TEXT, "请再问我一次吧") try: # load config @@ -58,35 +63,66 @@ def _chat(self, query, context, retry_count=0) -> Reply: logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context") app_code = None else: - app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code") + plugin_app_code = self._find_group_mapping_code(context) + app_code = context.kwargs.get("app_code") or plugin_app_code or conf().get("linkai_app_code") linkai_api_key = conf().get("linkai_api_key") session_id = context["session_id"] + session_message = self.sessions.session_msg_query(query, session_id) + logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}") + + # image process + img_cache = memory.USER_IMAGE_CACHE.get(session_id) + if img_cache: + messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache) + if messages: + 
session_message = messages - session = self.sessions.session_query(query, session_id) - model = conf().get("model") or "gpt-3.5-turbo" + model = conf().get("model") # remove system message - if session.messages[0].get("role") == "system": + if session_message[0].get("role") == "system": if app_code or model == "wenxin": - session.messages.pop(0) - + session_message.pop(0) body = { "app_code": app_code, - "messages": session.messages, + "messages": session_message, "model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei "temperature": conf().get("temperature"), "top_p": conf().get("top_p", 1), "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "session_id": session_id, + "sender_id": session_id, + "channel_type": conf().get("channel_type", "wx") } + try: + from linkai import LinkAIClient + client_id = LinkAIClient.fetch_client_id() + if client_id: + body["client_id"] = client_id + # start: client info deliver + if context.kwargs.get("msg"): + body["session_id"] = context.kwargs.get("msg").from_user_id + if context.kwargs.get("msg").is_group: + body["is_group"] = True + body["group_name"] = context.kwargs.get("msg").from_user_nickname + body["sender_name"] = context.kwargs.get("msg").actual_user_nickname + else: + if body.get("channel_type") in ["wechatcom_app"]: + body["sender_name"] = context.kwargs.get("msg").from_user_id + else: + body["sender_name"] = context.kwargs.get("msg").from_user_nickname + + except Exception as e: + pass file_id = context.kwargs.get("file_id") if file_id: body["file_id"] = file_id - logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}, file_id={file_id}") + logger.info(f"[LINKAI] query={query}, app_code={app_code}, model={body.get('model')}, file_id={file_id}") headers = {"Authorization": "Bearer " + linkai_api_key} # do http request - base_url = 
conf().get("linkai_api_base", "https://api.link-ai.chat") + base_url = conf().get("linkai_api_base", "https://api.link-ai.tech") res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers, timeout=conf().get("request_timeout", 180)) if res.status_code == 200: @@ -94,9 +130,12 @@ def _chat(self, query, context, retry_count=0) -> Reply: response = res.json() reply_content = response["choices"][0]["message"]["content"] total_tokens = response["usage"]["total_tokens"] - logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}") - self.sessions.session_reply(reply_content, session_id, total_tokens) - + res_code = response.get('code') + logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}, res_code={res_code}") + if res_code == 429: + logger.warn(f"[LINKAI] 用户访问超出限流配置,sender_id={body.get('sender_id')}") + else: + self.sessions.session_reply(reply_content, session_id, total_tokens, query=query) agent_suffix = self._fetch_agent_suffix(response) if agent_suffix: reply_content += agent_suffix @@ -104,6 +143,13 @@ def _chat(self, query, context, retry_count=0) -> Reply: knowledge_suffix = self._fetch_knowledge_search_suffix(response) if knowledge_suffix: reply_content += knowledge_suffix + # image process + if response["choices"][0].get("img_urls"): + thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls"))) + thread.start() + if response["choices"][0].get("text_content"): + reply_content = response["choices"][0].get("text_content") + reply_content = self._process_url(reply_content) return Reply(ReplyType.TEXT, reply_content) else: @@ -118,7 +164,10 @@ def _chat(self, query, context, retry_count=0) -> Reply: logger.warn(f"[LINKAI] do retry, times={retry_count}") return self._chat(query, context, retry_count + 1) - return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧") + error_reply = "提问太快啦,请休息一下再问我吧" + if res.status_code == 409: + error_reply = 
"这个问题我还没有学会,请问我其它问题吧" + return Reply(ReplyType.TEXT, error_reply) except Exception as e: logger.exception(e) @@ -127,6 +176,66 @@ def _chat(self, query, context, retry_count=0) -> Reply: logger.warn(f"[LINKAI] do retry, times={retry_count}") return self._chat(query, context, retry_count + 1) + def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict): + try: + enable_image_input = False + app_info = self._fetch_app_info(app_code) + if not app_info: + logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}") + return None + plugins = app_info.get("data").get("plugins") + for plugin in plugins: + if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"): + enable_image_input = True + if not enable_image_input: + return + msg = img_cache.get("msg") + path = img_cache.get("path") + msg.prepare() + logger.info(f"[LinkAI] query with images, path={path}") + messages = self._build_vision_msg(query, path) + memory.USER_IMAGE_CACHE[session_id] = None + return messages + except Exception as e: + logger.exception(e) + + def _find_group_mapping_code(self, context): + try: + if context.kwargs.get("isgroup"): + group_name = context.kwargs.get("msg").from_user_nickname + if config.plugin_config and config.plugin_config.get("linkai"): + linkai_config = config.plugin_config.get("linkai") + group_mapping = linkai_config.get("group_app_map") + if group_mapping and group_name: + return group_mapping.get(group_name) + except Exception as e: + logger.exception(e) + return None + + def _build_vision_msg(self, query: str, path: str): + try: + suffix = utils.get_path_suffix(path) + with open(path, "rb") as file: + base64_str = base64.b64encode(file.read()).decode('utf-8') + messages = [{ + "role": "user", + "content": [ + { + "type": "text", + "text": query + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/{suffix};base64,{base64_str}" + } + } + ] + }] + return messages + except Exception as e: + 
logger.exception(e) + def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict: if retry_count >= 2: # exit from retry 2 times @@ -149,10 +258,10 @@ def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dic } if self.args.get("max_tokens"): body["max_tokens"] = self.args.get("max_tokens") - headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} + headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} # do http request - base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + base_url = conf().get("linkai_api_base", "https://api.link-ai.tech") res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers, timeout=conf().get("request_timeout", 180)) if res.status_code == 200: @@ -192,6 +301,16 @@ def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dic logger.warn(f"[LINKAI] do retry, times={retry_count}") return self.reply_text(session, app_code, retry_count + 1) + def _fetch_app_info(self, app_code: str): + headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} + # do http request + base_url = conf().get("linkai_api_base", "https://api.link-ai.tech") + params = {"app_code": app_code} + res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10)) + if res.status_code == 200: + return res.json() + else: + logger.warning(f"[LinkAI] find app info exception, res={res}") def create_img(self, query, retry_count=0, api_key=None): try: @@ -207,7 +326,7 @@ def create_img(self, query, retry_count=0, api_key=None): "response_format": "url", "img_proxy": conf().get("image_proxy") } - url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations" + url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/images/generations" res = requests.post(url, headers=headers, json=data, timeout=(5, 90)) t2 = time.time() image_url = 
res.json()["data"][0]["url"] @@ -236,6 +355,7 @@ def _fetch_knowledge_search_suffix(self, response) -> str: except Exception as e: logger.exception(e) + def _fetch_agent_suffix(self, response): try: plugin_list = [] @@ -252,7 +372,9 @@ def _fetch_agent_suffix(self, response): suffix += f"{turn.get('thought')}\n" if plugin_name: plugin_list.append(turn.get('plugin_name')) - suffix += f"{turn.get('plugin_icon')} {turn.get('plugin_name')}" + if turn.get('plugin_icon'): + suffix += f"{turn.get('plugin_icon')} " + suffix += f"{turn.get('plugin_name')}" if turn.get('plugin_input'): suffix += f":{turn.get('plugin_input')}" if i < len(chain) - 1: @@ -262,3 +384,92 @@ def _fetch_agent_suffix(self, response): return suffix except Exception as e: logger.exception(e) + + def _process_url(self, text): + try: + url_pattern = re.compile(r'\[(.*?)\]\((http[s]?://.*?)\)') + def replace_markdown_url(match): + return f"{match.group(2)}" + return url_pattern.sub(replace_markdown_url, text) + except Exception as e: + logger.error(e) + + def _send_image(self, channel, context, image_urls): + if not image_urls: + return + max_send_num = conf().get("max_media_send_count") + send_interval = conf().get("media_send_interval") + file_type = (".pdf", ".doc", ".docx", ".csv", ".xls", ".xlsx", ".txt", ".rtf", ".ppt", ".pptx") + try: + i = 0 + for url in image_urls: + if max_send_num and i >= max_send_num: + continue + i += 1 + if url.endswith(".mp4"): + reply_type = ReplyType.VIDEO_URL + elif url.endswith(file_type): + reply_type = ReplyType.FILE + url = _download_file(url) + if not url: + continue + else: + reply_type = ReplyType.IMAGE_URL + reply = Reply(reply_type, url) + channel.send(reply, context) + if send_interval: + time.sleep(send_interval) + except Exception as e: + logger.error(e) + + +def _download_file(url: str): + try: + file_path = "tmp" + if not os.path.exists(file_path): + os.makedirs(file_path) + file_name = url.split("/")[-1] # 获取文件名 + file_path = os.path.join(file_path, 
file_name) + response = requests.get(url) + with open(file_path, "wb") as f: + f.write(response.content) + return file_path + except Exception as e: + logger.warn(e) + + +class LinkAISessionManager(SessionManager): + def session_msg_query(self, query, session_id): + session = self.build_session(session_id) + messages = session.messages + [{"role": "user", "content": query}] + return messages + + def session_reply(self, reply, session_id, total_tokens=None, query=None): + session = self.build_session(session_id) + if query: + session.add_query(query) + session.add_reply(reply) + try: + max_tokens = conf().get("conversation_max_tokens", 2500) + tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) + logger.debug(f"[LinkAI] chat history, before tokens={total_tokens}, now tokens={tokens_cnt}") + except Exception as e: + logger.warning("Exception when counting tokens precisely for session: {}".format(str(e))) + return session + + +class LinkAISession(ChatGPTSession): + def calc_tokens(self): + if not self.messages: + return 0 + return len(str(self.messages)) + + def discard_exceeding(self, max_tokens, cur_tokens=None): + cur_tokens = self.calc_tokens() + if cur_tokens > max_tokens: + for i in range(0, len(self.messages)): + if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user": + self.messages.pop(i) + self.messages.pop(i - 1) + return self.calc_tokens() + return cur_tokens diff --git a/bot/minimax/minimax_bot.py b/bot/minimax/minimax_bot.py new file mode 100644 index 000000000..40112d8ee --- /dev/null +++ b/bot/minimax/minimax_bot.py @@ -0,0 +1,151 @@ +# encoding:utf-8 + +import time + +import openai +import openai.error +from bot.bot import Bot +from bot.minimax.minimax_session import MinimaxSession +from bot.session_manager import SessionManager +from bridge.context import Context, ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger +from config import conf, load_config 
+from bot.chatgpt.chat_gpt_session import ChatGPTSession +import requests +from common import const + + +# ZhipuAI对话模型API +class MinimaxBot(Bot): + def __init__(self): + super().__init__() + self.args = { + "model": conf().get("model") or "abab6.5", # 对话模型的名称 + "temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。 + "top_p": conf().get("top_p", 0.95), # 使用默认值 + } + self.api_key = conf().get("Minimax_api_key") + self.group_id = conf().get("Minimax_group_id") + self.base_url = conf().get("Minimax_base_url", f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={self.group_id}") + # tokens_to_generate/bot_setting/reply_constraints可自行修改 + self.request_body = { + "model": self.args["model"], + "tokens_to_generate": 2048, + "reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"}, + "messages": [], + "bot_setting": [ + { + "bot_name": "MM智能助理", + "content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。", + } + ], + } + self.sessions = SessionManager(MinimaxSession, model=const.MiniMax) + + def reply(self, query, context: Context = None) -> Reply: + # acquire reply content + logger.info("[Minimax_AI] query={}".format(query)) + if context.type == ContextType.TEXT: + session_id = context["session_id"] + reply = None + clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) + if query in clear_memory_commands: + self.sessions.clear_session(session_id) + reply = Reply(ReplyType.INFO, "记忆已清除") + elif query == "#清除所有": + self.sessions.clear_all_session() + reply = Reply(ReplyType.INFO, "所有人记忆已清除") + elif query == "#更新配置": + load_config() + reply = Reply(ReplyType.INFO, "配置已更新") + if reply: + return reply + session = self.sessions.session_query(query, session_id) + logger.debug("[Minimax_AI] session query={}".format(session)) + + model = context.get("Minimax_model") + new_args = self.args.copy() + if model: + new_args["model"] = model + # if context.get('stream'): + # # reply in 
stream
+            #     return self.reply_text_stream(query, new_query, session_id)
+
+            reply_content = self.reply_text(session, args=new_args)
+            logger.debug(
+                "[Minimax_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                    session.messages,
+                    session_id,
+                    reply_content["content"],
+                    reply_content["completion_tokens"],
+                )
+            )
+            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+            elif reply_content["completion_tokens"] > 0:
+                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                reply = Reply(ReplyType.TEXT, reply_content["content"])
+            else:
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
+                logger.debug("[Minimax_AI] reply {} used 0 tokens.".format(reply_content))
+            return reply
+        else:
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+            return reply
+
+    def reply_text(self, session: MinimaxSession, args=None, retry_count=0) -> dict:
+        """
+        call Minimax's chat completion API to get the answer
+        :param session: a conversation session
+        :param args: request arguments
+        :param retry_count: retry count
+        :return: {}
+        """
+        try:
+            headers = {"Content-Type": "application/json", "Authorization": "Bearer " + self.api_key}
+            # 每次请求基于模板构造新的请求体,避免历史消息在 self.request_body 中不断累积
+            request_body = dict(self.request_body)
+            request_body["messages"] = list(session.messages)
+            logger.info("[Minimax_AI] request_body={}".format(request_body))
+            res = requests.post(self.base_url, headers=headers, json=request_body)
+            if res.status_code == 200:
+                response = res.json()
+                return {
+                    "total_tokens": response["usage"]["total_tokens"],
+                    # 接口的 usage 中没有单独的 completion_tokens,复用 total_tokens 以标记回复成功
+                    "completion_tokens": response["usage"]["total_tokens"],
+                    "content": response["reply"],
+                }
+            else:
+                response = res.json()
+                error = response.get("error") or {}
+                logger.error(f"[Minimax_AI] chat failed, status_code={res.status_code}, " f"msg={error.get('message')}, type={error.get('type')}")
+
+                result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
+                need_retry = False
+                if res.status_code >= 500:
+                    # server error, need retry
+                    logger.warn(f"[Minimax_AI] do retry, times={retry_count}")
+                    need_retry = retry_count < 2
+                elif res.status_code == 401:
+                    result["content"] = "授权失败,请检查API Key是否正确"
+                elif res.status_code == 429:
+                    result["content"] = "请求过于频繁,请稍后再试"
+                    need_retry = retry_count < 2
+                else:
+                    need_retry = False
+
+                if need_retry:
+                    time.sleep(3)
+                    return self.reply_text(session, args, retry_count + 1)
+                else:
+                    return result
+        except Exception as e:
+            logger.exception(e)
+            need_retry = retry_count < 2
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if need_retry:
+                return self.reply_text(session, args, retry_count + 1)
+            else:
+                return result
diff --git a/bot/minimax/minimax_session.py b/bot/minimax/minimax_session.py
new file mode 100644
index 000000000..1925b4bfe
--- /dev/null
+++ b/bot/minimax/minimax_session.py
@@ -0,0 +1,72 @@
+from bot.session_manager import Session
+from common.log import logger
+
+"""
+    e.g.
+ [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, + {"role": "user", "content": "Where was it played?"} + ] +""" + + +class MinimaxSession(Session): + def __init__(self, session_id, system_prompt=None, model="minimax"): + super().__init__(session_id, system_prompt) + self.model = model + # self.reset() + + def add_query(self, query): + user_item = {"sender_type": "USER", "sender_name": self.session_id, "text": query} + self.messages.append(user_item) + + def add_reply(self, reply): + assistant_item = {"sender_type": "BOT", "sender_name": "MM智能助理", "text": reply} + self.messages.append(assistant_item) + + def discard_exceeding(self, max_tokens, cur_tokens=None): + precise = True + try: + cur_tokens = self.calc_tokens() + except Exception as e: + precise = False + if cur_tokens is None: + raise e + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + while cur_tokens > max_tokens: + if len(self.messages) > 2: + self.messages.pop(1) + elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "BOT": + self.messages.pop(1) + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = cur_tokens - max_tokens + break + elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "USER": + logger.warn("user message exceed max_tokens. 
total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                break
+        if precise:
+            cur_tokens = self.calc_tokens()
+        else:
+            cur_tokens = cur_tokens - max_tokens
+        return cur_tokens
+
+    def calc_tokens(self):
+        return num_tokens_from_messages(self.messages, self.model)
+
+
+def num_tokens_from_messages(messages, model):
+    """Returns the number of tokens used by a list of messages."""
+    # 目前根据字符串长度粗略估计token数,不影响正常使用
+    tokens = 0
+    for msg in messages:
+        tokens += len(msg["text"])
+    return tokens
diff --git a/bot/moonshot/moonshot_bot.py b/bot/moonshot/moonshot_bot.py
new file mode 100644
index 000000000..7d2589cda
--- /dev/null
+++ b/bot/moonshot/moonshot_bot.py
@@ -0,0 +1,143 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+from bot.bot import Bot
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf, load_config
+from .moonshot_session import MoonshotSession
+import requests
+
+
+# Moonshot对话模型API
+class MoonshotBot(Bot):
+    def __init__(self):
+        super().__init__()
+        self.sessions = SessionManager(MoonshotSession, model=conf().get("model") or "moonshot-v1-128k")
+        self.args = {
+            "model": conf().get("model") or "moonshot-v1-128k",  # 对话模型的名称
+            "temperature": conf().get("temperature", 0.3),  # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
+            "top_p": conf().get("top_p", 1.0),  # 使用默认值
+        }
+        self.api_key = conf().get("moonshot_api_key")
+        self.base_url = conf().get("moonshot_base_url", "https://api.moonshot.cn/v1/chat/completions")
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context.type == ContextType.TEXT:
+            logger.info("[MOONSHOT_AI]
query={}".format(query)) + + session_id = context["session_id"] + reply = None + clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) + if query in clear_memory_commands: + self.sessions.clear_session(session_id) + reply = Reply(ReplyType.INFO, "记忆已清除") + elif query == "#清除所有": + self.sessions.clear_all_session() + reply = Reply(ReplyType.INFO, "所有人记忆已清除") + elif query == "#更新配置": + load_config() + reply = Reply(ReplyType.INFO, "配置已更新") + if reply: + return reply + session = self.sessions.session_query(query, session_id) + logger.debug("[MOONSHOT_AI] session query={}".format(session.messages)) + + model = context.get("moonshot_model") + new_args = self.args.copy() + if model: + new_args["model"] = model + # if context.get('stream'): + # # reply in stream + # return self.reply_text_stream(query, new_query, session_id) + + reply_content = self.reply_text(session, args=new_args) + logger.debug( + "[MOONSHOT_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( + session.messages, + session_id, + reply_content["content"], + reply_content["completion_tokens"], + ) + ) + if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0: + reply = Reply(ReplyType.ERROR, reply_content["content"]) + elif reply_content["completion_tokens"] > 0: + self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) + reply = Reply(ReplyType.TEXT, reply_content["content"]) + else: + reply = Reply(ReplyType.ERROR, reply_content["content"]) + logger.debug("[MOONSHOT_AI] reply {} used 0 tokens.".format(reply_content)) + return reply + else: + reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) + return reply + + def reply_text(self, session: MoonshotSession, args=None, retry_count=0) -> dict: + """ + call openai's ChatCompletion to get the answer + :param session: a conversation session + :param session_id: session id + :param retry_count: retry count + :return: {} + """ + try: 
+ headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.api_key + } + body = args + body["messages"] = session.messages + # logger.debug("[MOONSHOT_AI] response={}".format(response)) + # logger.info("[MOONSHOT_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) + res = requests.post( + self.base_url, + headers=headers, + json=body + ) + if res.status_code == 200: + response = res.json() + return { + "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response["usage"]["completion_tokens"], + "content": response["choices"][0]["message"]["content"] + } + else: + response = res.json() + error = response.get("error") + logger.error(f"[MOONSHOT_AI] chat failed, status_code={res.status_code}, " + f"msg={error.get('message')}, type={error.get('type')}") + + result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"} + need_retry = False + if res.status_code >= 500: + # server error, need retry + logger.warn(f"[MOONSHOT_AI] do retry, times={retry_count}") + need_retry = retry_count < 2 + elif res.status_code == 401: + result["content"] = "授权失败,请检查API Key是否正确" + elif res.status_code == 429: + result["content"] = "请求过于频繁,请稍后再试" + need_retry = retry_count < 2 + else: + need_retry = False + + if need_retry: + time.sleep(3) + return self.reply_text(session, args, retry_count + 1) + else: + return result + except Exception as e: + logger.exception(e) + need_retry = retry_count < 2 + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + if need_retry: + return self.reply_text(session, args, retry_count + 1) + else: + return result diff --git a/bot/moonshot/moonshot_session.py b/bot/moonshot/moonshot_session.py new file mode 100644 index 000000000..63e08f584 --- /dev/null +++ b/bot/moonshot/moonshot_session.py @@ -0,0 +1,51 @@ +from bot.session_manager import Session +from common.log import logger + + +class MoonshotSession(Session): + def 
__init__(self, session_id, system_prompt=None, model="moonshot-v1-128k"): + super().__init__(session_id, system_prompt) + self.model = model + self.reset() + + def discard_exceeding(self, max_tokens, cur_tokens=None): + precise = True + try: + cur_tokens = self.calc_tokens() + except Exception as e: + precise = False + if cur_tokens is None: + raise e + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + while cur_tokens > max_tokens: + if len(self.messages) > 2: + self.messages.pop(1) + elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": + self.messages.pop(1) + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = cur_tokens - max_tokens + break + elif len(self.messages) == 2 and self.messages[1]["role"] == "user": + logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens)) + break + else: + logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, + len(self.messages))) + break + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = cur_tokens - max_tokens + return cur_tokens + + def calc_tokens(self): + return num_tokens_from_messages(self.messages, self.model) + + +def num_tokens_from_messages(messages, model): + tokens = 0 + for msg in messages: + tokens += len(msg["content"]) + return tokens diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py index 974bf8256..3ff56c175 100644 --- a/bot/openai/open_ai_image.py +++ b/bot/openai/open_ai_image.py @@ -15,7 +15,7 @@ def __init__(self): if conf().get("rate_limit_dalle"): self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50)) - def create_img(self, query, retry_count=0, api_key=None): + def create_img(self, query, retry_count=0, api_key=None, api_base=None): try: if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token(): return False, "请求太快了,请休息一下再问我吧" diff --git a/bot/session_manager.py b/bot/session_manager.py index 
8d70886e0..a6e89f956 100644 --- a/bot/session_manager.py +++ b/bot/session_manager.py @@ -69,7 +69,7 @@ def session_query(self, query, session_id): total_tokens = session.discard_exceeding(max_tokens, None) logger.debug("prompt tokens used={}".format(total_tokens)) except Exception as e: - logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e))) + logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e))) return session def session_reply(self, reply, session_id, total_tokens=None): @@ -80,7 +80,7 @@ def session_reply(self, reply, session_id, total_tokens=None): tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) except Exception as e: - logger.debug("Exception when counting tokens precisely for session: {}".format(str(e))) + logger.warning("Exception when counting tokens precisely for session: {}".format(str(e))) return session def clear_session(self, session_id): diff --git a/bot/xunfei/xunfei_spark_bot.py b/bot/xunfei/xunfei_spark_bot.py index eac57f56c..30cab2fc2 100644 --- a/bot/xunfei/xunfei_spark_bot.py +++ b/bot/xunfei/xunfei_spark_bot.py @@ -3,7 +3,7 @@ import requests, json from bot.bot import Bot from bot.session_manager import SessionManager -from bot.baidu.baidu_wenxin_session import BaiduWenxinSession +from bot.chatgpt.chat_gpt_session import ChatGPTSession from bridge.context import ContextType, Context from bridge.reply import Reply, ReplyType from common.log import logger @@ -40,14 +40,20 @@ def __init__(self): self.app_id = conf().get("xunfei_app_id") self.api_key = conf().get("xunfei_api_key") self.api_secret = conf().get("xunfei_api_secret") - # 默认使用v2.0版本,1.5版本可设置为 general - self.domain = "generalv2" - # 默认使用v2.0版本,1.5版本可设置为 "ws://spark-api.xf-yun.com/v1.1/chat" - self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" + # 默认使用v2.0版本: "generalv2" + # Spark Lite请求地址(spark_url): 
wss://spark-api.xf-yun.com/v1.1/chat, 对应的domain参数为: "general" + # Spark V2.0请求地址(spark_url): wss://spark-api.xf-yun.com/v2.1/chat, 对应的domain参数为: "generalv2" + # Spark Pro 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.1/chat, 对应的domain参数为: "generalv3" + # Spark Pro-128K请求地址(spark_url): wss://spark-api.xf-yun.com/chat/pro-128k, 对应的domain参数为: "pro-128k" + # Spark Max 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.5/chat, 对应的domain参数为: "generalv3.5" + # Spark4.0 Ultra 请求地址(spark_url): wss://spark-api.xf-yun.com/v4.0/chat, 对应的domain参数为: "4.0Ultra" + # 后续模型更新,对应的参数可以参考官网文档获取:https://www.xfyun.cn/doc/spark/Web.html + self.domain = conf().get("xunfei_domain", "generalv3.5") + self.spark_url = conf().get("xunfei_spark_url", "wss://spark-api.xf-yun.com/v3.5/chat") self.host = urlparse(self.spark_url).netloc self.path = urlparse(self.spark_url).path # 和wenxin使用相同的session机制 - self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI) + self.sessions = SessionManager(ChatGPTSession, model=const.XUNFEI) def reply(self, query, context: Context = None) -> Reply: if context.type == ContextType.TEXT: @@ -56,7 +62,8 @@ def reply(self, query, context: Context = None) -> Reply: request_id = self.gen_request_id(session_id) reply_map[request_id] = "" session = self.sessions.session_query(query, session_id) - threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start() + threading.Thread(target=self.create_web_socket, + args=(session.messages, request_id)).start() depth = 0 time.sleep(0.1) t1 = time.time() @@ -83,20 +90,27 @@ def reply(self, query, context: Context = None) -> Reply: depth += 1 continue t2 = time.time() - logger.info(f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}") - self.sessions.session_reply(reply_map[request_id], session_id, usage.get("total_tokens")) + logger.info( + f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}" + ) + 
self.sessions.session_reply(reply_map[request_id], session_id, + usage.get("total_tokens")) reply = Reply(ReplyType.TEXT, reply_map[request_id]) del reply_map[request_id] return reply else: - reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) + reply = Reply(ReplyType.ERROR, + "Bot不支持处理{}类型的消息".format(context.type)) return reply def create_web_socket(self, prompt, session_id, temperature=0.5): logger.info(f"[XunFei] start connect, prompt={prompt}") websocket.enableTrace(False) wsUrl = self.create_url() - ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, + ws = websocket.WebSocketApp(wsUrl, + on_message=on_message, + on_error=on_error, + on_close=on_close, on_open=on_open) data_queue = queue.Queue(1000) queue_map[session_id] = data_queue @@ -108,7 +122,8 @@ def create_web_socket(self, prompt, session_id, temperature=0.5): ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def gen_request_id(self, session_id: str): - return session_id + "_" + str(int(time.time())) + "" + str(random.randint(0, 100)) + return session_id + "_" + str(int(time.time())) + "" + str( + random.randint(0, 100)) # 生成url def create_url(self): @@ -122,22 +137,21 @@ def create_url(self): signature_origin += "GET " + self.path + " HTTP/1.1" # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'), + signature_sha = hmac.new(self.api_secret.encode('utf-8'), + signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + signature_sha_base64 = base64.b64encode(signature_sha).decode( + encoding='utf-8') authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \ f'signature="{signature_sha_base64}"' - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + authorization = base64.b64encode( + 
authorization_origin.encode('utf-8')).decode(encoding='utf-8')
         # 将请求的鉴权参数组合为字典
-        v = {
-            "authorization": authorization,
-            "date": date,
-            "host": self.host
-        }
+        v = {"authorization": authorization, "date": date, "host": self.host}
         # 拼接鉴权参数,生成url
         url = self.spark_url + '?' + urlencode(v)
         # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
@@ -190,11 +204,15 @@ def on_close(ws, one, two):
 # 收到websocket连接建立的处理
 def on_open(ws):
     logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
-    thread.start_new_thread(run, (ws,))
+    thread.start_new_thread(run, (ws, ))


 def run(ws, *args):
-    data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question, temperature=ws.temperature))
+    data = json.dumps(
+        gen_params(appid=ws.appid,
+                   domain=ws.domain,
+                   question=ws.question,
+                   temperature=ws.temperature))
     ws.send(data)

@@ -212,7 +230,8 @@ def on_message(ws, message):
         content = choices["text"][0]["content"]
         data_queue = queue_map.get(ws.session_id)
         if not data_queue:
-            logger.error(f"[XunFei] can't find data queue, session_id={ws.session_id}")
+            logger.error(
+                f"[XunFei] can't find data queue, session_id={ws.session_id}")
             return
         reply_item = ReplyItem(content)
         if status == 2:
diff --git a/bot/zhipuai/zhipu_ai_image.py b/bot/zhipuai/zhipu_ai_image.py
new file mode 100644
index 000000000..84eb5671e
--- /dev/null
+++ b/bot/zhipuai/zhipu_ai_image.py
@@ -0,0 +1,29 @@
+from common.log import logger
+from common.token_bucket import TokenBucket
+from config import conf
+
+
+# ZhipuAI提供的画图接口
+
+class ZhipuAIImage(object):
+    def __init__(self):
+        from zhipuai import ZhipuAI
+        self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
+        if conf().get("rate_limit_dalle"):
+            # 与 open_ai_image 保持一致,用令牌桶限流,而不是配置了限流就直接拒绝请求
+            self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
+
+    def create_img(self, query, retry_count=0, api_key=None, api_base=None):
+        try:
+            if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
+                return False, "请求太快了,请休息一下再问我吧"
+            logger.info("[ZHIPU_AI] image_query={}".format(query))
+            response = self.client.images.generations(
+                prompt=query,
+                n=1,  # 每次生成图片的数量
+                model=conf().get("text_to_image") or "cogview-3",
+
size=conf().get("image_create_size", "1024x1024"), # 图片大小,可选有 256x256, 512x512, 1024x1024 + quality="standard", + ) + image_url = response.data[0].url + logger.info("[ZHIPU_AI] image_url={}".format(image_url)) + return True, image_url + except Exception as e: + logger.exception(e) + return False, "画图出现问题,请休息一下再问我吧" diff --git a/bot/zhipuai/zhipu_ai_session.py b/bot/zhipuai/zhipu_ai_session.py new file mode 100644 index 000000000..846d36a7b --- /dev/null +++ b/bot/zhipuai/zhipu_ai_session.py @@ -0,0 +1,53 @@ +from bot.session_manager import Session +from common.log import logger + + +class ZhipuAISession(Session): + def __init__(self, session_id, system_prompt=None, model="glm-4"): + super().__init__(session_id, system_prompt) + self.model = model + self.reset() + if not system_prompt: + logger.warn("[ZhiPu] `character_desc` can not be empty") + + def discard_exceeding(self, max_tokens, cur_tokens=None): + precise = True + try: + cur_tokens = self.calc_tokens() + except Exception as e: + precise = False + if cur_tokens is None: + raise e + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + while cur_tokens > max_tokens: + if len(self.messages) > 2: + self.messages.pop(1) + elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": + self.messages.pop(1) + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = cur_tokens - max_tokens + break + elif len(self.messages) == 2 and self.messages[1]["role"] == "user": + logger.warn("user message exceed max_tokens. 
total_tokens={}".format(cur_tokens)) + break + else: + logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, + len(self.messages))) + break + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = cur_tokens - max_tokens + return cur_tokens + + def calc_tokens(self): + return num_tokens_from_messages(self.messages, self.model) + + +def num_tokens_from_messages(messages, model): + tokens = 0 + for msg in messages: + tokens += len(msg["content"]) + return tokens diff --git a/bot/zhipuai/zhipuai_bot.py b/bot/zhipuai/zhipuai_bot.py new file mode 100644 index 000000000..d8eed4d35 --- /dev/null +++ b/bot/zhipuai/zhipuai_bot.py @@ -0,0 +1,149 @@ +# encoding:utf-8 + +import time + +import openai +import openai.error +from bot.bot import Bot +from bot.zhipuai.zhipu_ai_session import ZhipuAISession +from bot.zhipuai.zhipu_ai_image import ZhipuAIImage +from bot.session_manager import SessionManager +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger +from config import conf, load_config +from zhipuai import ZhipuAI + + +# ZhipuAI对话模型API +class ZHIPUAIBot(Bot, ZhipuAIImage): + def __init__(self): + super().__init__() + self.sessions = SessionManager(ZhipuAISession, model=conf().get("model") or "ZHIPU_AI") + self.args = { + "model": conf().get("model") or "glm-4", # 对话模型的名称 + "temperature": conf().get("temperature", 0.9), # 值在(0,1)之间(智谱AI 的温度不能取 0 或者 1) + "top_p": conf().get("top_p", 0.7), # 值在(0,1)之间(智谱AI 的 top_p 不能取 0 或者 1) + } + self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key")) + + def reply(self, query, context=None): + # acquire reply content + if context.type == ContextType.TEXT: + logger.info("[ZHIPU_AI] query={}".format(query)) + + session_id = context["session_id"] + reply = None + clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) + if query in clear_memory_commands: + self.sessions.clear_session(session_id) + reply = 
Reply(ReplyType.INFO, "记忆已清除") + elif query == "#清除所有": + self.sessions.clear_all_session() + reply = Reply(ReplyType.INFO, "所有人记忆已清除") + elif query == "#更新配置": + load_config() + reply = Reply(ReplyType.INFO, "配置已更新") + if reply: + return reply + session = self.sessions.session_query(query, session_id) + logger.debug("[ZHIPU_AI] session query={}".format(session.messages)) + + api_key = context.get("openai_api_key") or openai.api_key + model = context.get("gpt_model") + new_args = None + if model: + new_args = self.args.copy() + new_args["model"] = model + # if context.get('stream'): + # # reply in stream + # return self.reply_text_stream(query, new_query, session_id) + + reply_content = self.reply_text(session, api_key, args=new_args) + logger.debug( + "[ZHIPU_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( + session.messages, + session_id, + reply_content["content"], + reply_content["completion_tokens"], + ) + ) + if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0: + reply = Reply(ReplyType.ERROR, reply_content["content"]) + elif reply_content["completion_tokens"] > 0: + self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) + reply = Reply(ReplyType.TEXT, reply_content["content"]) + else: + reply = Reply(ReplyType.ERROR, reply_content["content"]) + logger.debug("[ZHIPU_AI] reply {} used 0 tokens.".format(reply_content)) + return reply + elif context.type == ContextType.IMAGE_CREATE: + ok, retstring = self.create_img(query, 0) + reply = None + if ok: + reply = Reply(ReplyType.IMAGE_URL, retstring) + else: + reply = Reply(ReplyType.ERROR, retstring) + return reply + + else: + reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) + return reply + + def reply_text(self, session: ZhipuAISession, api_key=None, args=None, retry_count=0) -> dict: + """ + call openai's ChatCompletion to get the answer + :param session: a conversation session + :param 
session_id: session id + :param retry_count: retry count + :return: {} + """ + try: + # if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): + # raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") + # if api_key == None, the default openai.api_key will be used + if args is None: + args = self.args + # response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args) + response = self.client.chat.completions.create(messages=session.messages, **args) + # logger.debug("[ZHIPU_AI] response={}".format(response)) + # logger.info("[ZHIPU_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) + + return { + "total_tokens": response.usage.total_tokens, + "completion_tokens": response.usage.completion_tokens, + "content": response.choices[0].message.content, + } + except Exception as e: + need_retry = retry_count < 2 + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + if isinstance(e, openai.error.RateLimitError): + logger.warn("[ZHIPU_AI] RateLimitError: {}".format(e)) + result["content"] = "提问太快啦,请休息一下再问我吧" + if need_retry: + time.sleep(20) + elif isinstance(e, openai.error.Timeout): + logger.warn("[ZHIPU_AI] Timeout: {}".format(e)) + result["content"] = "我没有收到你的消息" + if need_retry: + time.sleep(5) + elif isinstance(e, openai.error.APIError): + logger.warn("[ZHIPU_AI] Bad Gateway: {}".format(e)) + result["content"] = "请再问我一次" + if need_retry: + time.sleep(10) + elif isinstance(e, openai.error.APIConnectionError): + logger.warn("[ZHIPU_AI] APIConnectionError: {}".format(e)) + result["content"] = "我连接不到你的网络" + if need_retry: + time.sleep(5) + else: + logger.exception("[ZHIPU_AI] Exception: {}".format(e), e) + need_retry = False + self.sessions.clear_session(session.session_id) + + if need_retry: + logger.warn("[ZHIPU_AI] 第{}次重试".format(retry_count + 1)) + return self.reply_text(session, api_key, args, retry_count + 1) + else: + return 
result diff --git a/bridge/bridge.py b/bridge/bridge.py index bceca100c..b7b3ebf84 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -18,26 +18,51 @@ def __init__(self): "text_to_voice": conf().get("text_to_voice", "google"), "translate": conf().get("translate", "baidu"), } - model_type = conf().get("model") - if model_type in ["text-davinci-003"]: - self.btype["chat"] = const.OPEN_AI - if conf().get("use_azure_chatgpt", False): - self.btype["chat"] = const.CHATGPTONAZURE - if model_type in ["wenxin", "wenxin-4"]: - self.btype["chat"] = const.BAIDU - if model_type in ["xunfei"]: - self.btype["chat"] = const.XUNFEI - if conf().get("use_linkai") and conf().get("linkai_api_key"): - self.btype["chat"] = const.LINKAI - if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]: - self.btype["voice_to_text"] = const.LINKAI - if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: - self.btype["text_to_voice"] = const.LINKAI - if model_type in ["claude"]: - self.btype["chat"] = const.CLAUDEAI + # 这边取配置的模型 + bot_type = conf().get("bot_type") + if bot_type: + self.btype["chat"] = bot_type + else: + model_type = conf().get("model") or const.GPT35 + if model_type in ["text-davinci-003"]: + self.btype["chat"] = const.OPEN_AI + if conf().get("use_azure_chatgpt", False): + self.btype["chat"] = const.CHATGPTONAZURE + if model_type in ["wenxin", "wenxin-4"]: + self.btype["chat"] = const.BAIDU + if model_type in ["xunfei"]: + self.btype["chat"] = const.XUNFEI + if model_type in [const.QWEN]: + self.btype["chat"] = const.QWEN + if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]: + self.btype["chat"] = const.QWEN_DASHSCOPE + if model_type and model_type.startswith("gemini"): + self.btype["chat"] = const.GEMINI + if model_type in [const.ZHIPU_AI]: + self.btype["chat"] = const.ZHIPU_AI + if model_type and model_type.startswith("claude-3"): + self.btype["chat"] = const.CLAUDEAPI + + if 
model_type in ["claude"]: + self.btype["chat"] = const.CLAUDEAI + + if model_type in ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]: + self.btype["chat"] = const.MOONSHOT + + if model_type in ["abab6.5-chat"]: + self.btype["chat"] = const.MiniMax + + if conf().get("use_linkai") and conf().get("linkai_api_key"): + self.btype["chat"] = const.LINKAI + if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]: + self.btype["voice_to_text"] = const.LINKAI + if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: + self.btype["text_to_voice"] = const.LINKAI + self.bots = {} self.chat_bots = {} + # 模型对应的接口 def get_bot(self, typename): if self.bots.get(typename) is None: logger.info("create bot {} for {}".format(self.btype[typename], typename)) diff --git a/bridge/context.py b/bridge/context.py index 1e5958c1d..04d63209c 100644 --- a/bridge/context.py +++ b/bridge/context.py @@ -16,6 +16,8 @@ class ContextType(Enum): JOIN_GROUP = 20 # 加入群聊 PATPAT = 21 # 拍了拍 FUNCTION = 22 # 函数调用 + EXIT_GROUP = 23 #退出 + def __str__(self): return self.name diff --git a/bridge/reply.py b/bridge/reply.py index 00314845e..f2293bdfb 100644 --- a/bridge/reply.py +++ b/bridge/reply.py @@ -11,7 +11,7 @@ class ReplyType(Enum): VIDEO_URL = 5 # 视频URL FILE = 6 # 文件 CARD = 7 # 微信名片,仅支持ntchat - InviteRoom = 8 # 邀请好友进群 + INVITE_ROOM = 8 # 邀请好友进群 INFO = 9 ERROR = 10 TEXT_ = 11 # 强制文本 diff --git a/channel/channel.py b/channel/channel.py index 6464d771e..c22534273 100644 --- a/channel/channel.py +++ b/channel/channel.py @@ -8,6 +8,7 @@ class Channel(object): + channel_type = "" NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE] def startup(self): diff --git a/channel/channel_factory.py b/channel/channel_factory.py index 8c45045c0..c2c6937cf 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -1,40 +1,45 @@ """ channel factory """ +from common import const +from .channel import Channel -def 
create_channel(channel_type): +def create_channel(channel_type) -> Channel: """ create a channel instance :param channel_type: channel type code :return: channel instance """ + ch = Channel() if channel_type == "wx": from channel.wechat.wechat_channel import WechatChannel - - return WechatChannel() + ch = WechatChannel() elif channel_type == "wxy": from channel.wechat.wechaty_channel import WechatyChannel - - return WechatyChannel() + ch = WechatyChannel() elif channel_type == "terminal": from channel.terminal.terminal_channel import TerminalChannel - - return TerminalChannel() + ch = TerminalChannel() elif channel_type == "wechatmp": from channel.wechatmp.wechatmp_channel import WechatMPChannel - - return WechatMPChannel(passive_reply=True) + ch = WechatMPChannel(passive_reply=True) elif channel_type == "wechatmp_service": from channel.wechatmp.wechatmp_channel import WechatMPChannel - - return WechatMPChannel(passive_reply=False) + ch = WechatMPChannel(passive_reply=False) elif channel_type == "wechatcom_app": from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel - - return WechatComAppChannel() + ch = WechatComAppChannel() elif channel_type == "wework": from channel.wework.wework_channel import WeworkChannel - - return WeworkChannel() - raise RuntimeError + ch = WeworkChannel() + elif channel_type == const.FEISHU: + from channel.feishu.feishu_channel import FeiShuChanel + ch = FeiShuChanel() + elif channel_type == const.DINGTALK: + from channel.dingtalk.dingtalk_channel import DingTalkChanel + ch = DingTalkChanel() + else: + raise RuntimeError + ch.channel_type = channel_type + return ch diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 648eaaccc..27f3af0d7 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -9,8 +9,7 @@ from bridge.reply import * from channel.channel import Channel from common.dequeue import Dequeue -from common.log import logger -from config import conf +from common import memory from 
plugins import * try: @@ -18,6 +17,8 @@ except Exception as e: pass +handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 + # 抽象类, 它包含了与消息通道无关的通用处理逻辑 class ChatChannel(Channel): @@ -26,7 +27,6 @@ class ChatChannel(Channel): futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 lock = threading.Lock() # 用于控制对sessions的访问 - handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 def __init__(self): _thread = threading.Thread(target=self.consume) @@ -74,6 +74,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): ): session_id = group_id else: + logger.debug(f"No need to reply, group name not in whitelist, group_name={group_name}") return None context["session_id"] = session_id context["receiver"] = group_id @@ -85,16 +86,17 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): if e_context.is_pass() or context is None: return context if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True): - logger.debug("[WX]self message skipped") + logger.debug("[chat_channel]self message skipped") return None # 消息内容匹配过程,并处理content if ctype == ContextType.TEXT: if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 logger.debug(content) - logger.debug("[WX]reference query skipped") + logger.debug("[chat_channel]reference query skipped") return None + nick_name_black_list = conf().get("nick_name_black_list", []) if context.get("isgroup", False): # 群聊 # 校验关键字 match_prefix = check_prefix(content, conf().get("group_chat_prefix")) @@ -106,9 +108,16 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): if match_prefix: content = content.replace(match_prefix, "", 1).strip() if context["msg"].is_at: - logger.info("[WX]receive group at") + nick_name = context["msg"].actual_user_nickname + if nick_name and nick_name in nick_name_black_list: + # 黑名单过滤 + logger.warning(f"[chat_channel] Nickname {nick_name} in BlackList, ignore") 
+ return None + + logger.info("[chat_channel]receive group at") if not conf().get("group_at_off", False): flag = True + self.name = self.name if self.name is not None else "" # 部分渠道self.name可能没有赋值 pattern = f"@{re.escape(self.name)}(\u2005|\u0020)" subtract_res = re.sub(pattern, r"", content) if isinstance(context["msg"].at_list, list): @@ -122,9 +131,15 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): content = subtract_res if not flag: if context["origin_ctype"] == ContextType.VOICE: - logger.info("[WX]receive group voice, but checkprefix didn't match") + logger.info("[chat_channel]receive group voice, but checkprefix didn't match") return None else: # 单聊 + nick_name = context["msg"].from_user_nickname + if nick_name and nick_name in nick_name_black_list: + # 黑名单过滤 + logger.warning(f"[chat_channel] Nickname '{nick_name}' in BlackList, ignore") + return None + match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""])) if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 content = content.replace(match_prefix, "", 1).strip() @@ -133,7 +148,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): else: return None content = content.strip() - img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) + img_match_prefix = check_prefix(content, conf().get("image_create_prefix", [""])) if img_match_prefix: content = content.replace(img_match_prefix, "", 1) context.type = ContextType.IMAGE_CREATE @@ -145,22 +160,23 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): elif context.type == ContextType.VOICE: if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: context["desire_rtype"] = ReplyType.VOICE - return context def _handle(self, context: Context): if context is None or not context.content: return - logger.debug("[WX] ready to handle context: {}".format(context)) + logger.debug("[chat_channel] ready to 
handle context: {}".format(context)) # reply的构建步骤 reply = self._generate_reply(context) - logger.debug("[WX] ready to decorate reply: {}".format(reply)) + logger.debug("[chat_channel] ready to decorate reply: {}".format(reply)) + # reply的包装步骤 - reply = self._decorate_reply(context, reply) + if reply and reply.content: + reply = self._decorate_reply(context, reply) - # reply的发送步骤 - self._send_reply(context, reply) + # reply的发送步骤 + self._send_reply(context, reply) def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: e_context = PluginManager().emit_event( @@ -171,10 +187,9 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: ) reply = e_context["reply"] if not e_context.is_pass(): - logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content)) - if e_context.is_break(): - context["generate_breaked_by"] = e_context["breaked_by"] + logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content)) if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 + context["channel"] = e_context["channel"] reply = super().build_reply_content(context.content, context) elif context.type == ContextType.VOICE: # 语音消息 cmsg = context["msg"] @@ -184,7 +199,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: try: any_to_wav(file_path, wav_path) except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 - logger.warning("[WX]any to wav error, use raw path. " + str(e)) + logger.warning("[chat_channel]any to wav error, use raw path. 
" + str(e)) wav_path = file_path # 语音识别 reply = super().build_voice_to_text(wav_path) @@ -195,7 +210,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: os.remove(wav_path) except Exception as e: pass - # logger.warning("[WX]delete temp file error: " + str(e)) + # logger.warning("[chat_channel]delete temp file error: " + str(e)) if reply.type == ReplyType.TEXT: new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs) @@ -204,14 +219,16 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: else: return elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑 - cmsg = context["msg"] - cmsg.prepare() + memory.USER_IMAGE_CACHE[context["session_id"]] = { + "path": context.content, + "msg": context.get("msg") + } elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑 pass elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑 pass else: - logger.error("[WX] unknown context type: {}".format(context.type)) + logger.warning("[chat_channel] unknown context type: {}".format(context.type)) return return reply @@ -227,7 +244,7 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply: desire_rtype = context.get("desire_rtype") if not e_context.is_pass() and reply and reply.type: if reply.type in self.NOT_SUPPORT_REPLYTYPE: - logger.error("[WX]reply type not support: " + str(reply.type)) + logger.error("[chat_channel]reply type not support: " + str(reply.type)) reply.type = ReplyType.ERROR reply.content = "不支持发送的消息类型: " + str(reply.type) @@ -237,7 +254,8 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply: reply = super().build_text_to_voice(reply.content) return self._decorate_reply(context, reply) if context.get("isgroup", False): - reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip() + if not context.get("no_need_at", False): + reply_text = "@" + 
context["msg"].actual_user_nickname + "\n" + reply_text.strip() reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "") else: reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "") @@ -247,10 +265,10 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply: elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL: pass else: - logger.error("[WX] unknown reply type: {}".format(reply.type)) + logger.error("[chat_channel] unknown reply type: {}".format(reply.type)) return if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: - logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type)) + logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type)) return reply def _send_reply(self, context: Context, reply: Reply): @@ -263,14 +281,14 @@ def _send_reply(self, context: Context, reply: Reply): ) reply = e_context["reply"] if not e_context.is_pass() and reply and reply.type: - logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context)) + logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context)) self._send(reply, context) def _send(self, reply: Reply, context: Context, retry_cnt=0): try: self.send(reply, context) except Exception as e: - logger.error("[WX] sendMsg error: {}".format(str(e))) + logger.error("[chat_channel] sendMsg error: {}".format(str(e))) if isinstance(e, NotImplementedError): return logger.exception(e) @@ -324,8 +342,8 @@ def consume(self): if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除 if not context_queue.empty(): context = context_queue.get() - logger.debug("[WX] consume 
context: {}".format(context)) - future: Future = self.handler_pool.submit(self._handle, context) + logger.debug("[chat_channel] consume context: {}".format(context)) + future: Future = handler_pool.submit(self._handle, context) future.add_done_callback(self._thread_pool_callback(session_id, context=context)) if session_id not in self.futures: self.futures[session_id] = [] diff --git a/channel/dingtalk/dingtalk_channel.py b/channel/dingtalk/dingtalk_channel.py new file mode 100644 index 000000000..6c99e5f6c --- /dev/null +++ b/channel/dingtalk/dingtalk_channel.py @@ -0,0 +1,225 @@ +""" +钉钉通道接入 + +@author huiwen +@Date 2023/11/28 +""" +import copy +import json +# -*- coding=utf-8 -*- +import logging +import time + +import dingtalk_stream +from dingtalk_stream import AckMessage +from dingtalk_stream.card_replier import AICardReplier +from dingtalk_stream.card_replier import AICardStatus +from dingtalk_stream.card_replier import CardReplier + +from bridge.context import Context, ContextType +from bridge.reply import Reply, ReplyType +from channel.chat_channel import ChatChannel +from channel.dingtalk.dingtalk_message import DingTalkMessage +from common.expired_dict import ExpiredDict +from common.log import logger +from common.singleton import singleton +from common.time_check import time_checker +from config import conf + + +class CustomAICardReplier(CardReplier): + def __init__(self, dingtalk_client, incoming_message): + super(AICardReplier, self).__init__(dingtalk_client, incoming_message) + + def start( + self, + card_template_id: str, + card_data: dict, + recipients: list = None, + support_forward: bool = True, + ) -> str: + """ + AI卡片的创建接口 + :param support_forward: + :param recipients: + :param card_template_id: + :param card_data: + :return: + """ + card_data_with_status = copy.deepcopy(card_data) + card_data_with_status["flowStatus"] = AICardStatus.PROCESSING + return self.create_and_send_card( + card_template_id, + card_data_with_status, + at_sender=True, + 
at_all=False, + recipients=recipients, + support_forward=support_forward, + ) + + +# 对 AICardReplier 进行猴子补丁 +AICardReplier.start = CustomAICardReplier.start + + +def _check(func): + def wrapper(self, cmsg: DingTalkMessage): + msgId = cmsg.msg_id + if msgId in self.receivedMsgs: + logger.info("DingTalk message {} already received, ignore".format(msgId)) + return + self.receivedMsgs[msgId] = True + create_time = cmsg.create_time # 消息时间戳 + if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 + logger.debug("[DingTalk] History message {} skipped".format(msgId)) + return + if cmsg.my_msg and not cmsg.is_group: + logger.debug("[DingTalk] My message {} skipped".format(msgId)) + return + return func(self, cmsg) + + return wrapper + + +@singleton +class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler): + dingtalk_client_id = conf().get('dingtalk_client_id') + dingtalk_client_secret = conf().get('dingtalk_client_secret') + + def setup_logger(self): + logger = logging.getLogger() + handler = logging.StreamHandler() + handler.setFormatter( + logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]')) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + return logger + + def __init__(self): + super().__init__() + super(dingtalk_stream.ChatbotHandler, self).__init__() + self.logger = self.setup_logger() + # 历史消息id暂存,用于幂等控制 + self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600)) + logger.info("[DingTalk] client_id={}, client_secret={} ".format( + self.dingtalk_client_id, self.dingtalk_client_secret)) + # 无需群校验和前缀 + conf()["group_name_white_list"] = ["ALL_GROUP"] + # 单聊无需前缀 + conf()["single_chat_prefix"] = [""] + + def startup(self): + credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret) + client = dingtalk_stream.DingTalkStreamClient(credential) + 
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self) + client.start_forever() + + async def process(self, callback: dingtalk_stream.CallbackMessage): + try: + incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data) + image_download_handler = self # 传入方法所在的类实例 + dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler) + + if dingtalk_msg.is_group: + self.handle_group(dingtalk_msg) + else: + self.handle_single(dingtalk_msg) + return AckMessage.STATUS_OK, 'OK' + except Exception as e: + logger.error(f"dingtalk process error={e}") + return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR' + + @time_checker + @_check + def handle_single(self, cmsg: DingTalkMessage): + # 处理单聊消息 + if cmsg.ctype == ContextType.VOICE: + logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content)) + elif cmsg.ctype == ContextType.IMAGE: + logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content)) + elif cmsg.ctype == ContextType.IMAGE_CREATE: + logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content)) + elif cmsg.ctype == ContextType.PATPAT: + logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content)) + elif cmsg.ctype == ContextType.TEXT: + logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content)) + else: + logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content)) + context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) + if context: + self.produce(context) + + + @time_checker + @_check + def handle_group(self, cmsg: DingTalkMessage): + # 处理群聊消息 + if cmsg.ctype == ContextType.VOICE: + logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content)) + elif cmsg.ctype == ContextType.IMAGE: + logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content)) + elif cmsg.ctype == ContextType.IMAGE_CREATE: + logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content)) + elif cmsg.ctype == ContextType.PATPAT: + 
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content)) + elif cmsg.ctype == ContextType.TEXT: + logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content)) + else: + logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content)) + context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) + context['no_need_at'] = True + if context: + self.produce(context) + + + def send(self, reply: Reply, context: Context): + receiver = context["receiver"] + isgroup = context.kwargs['msg'].is_group + incoming_message = context.kwargs['msg'].incoming_message + + if conf().get("dingtalk_card_enabled"): + logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver)) + def reply_with_text(): + self.reply_text(reply.content, incoming_message) + def reply_with_at_text(): + self.reply_text("📢 您有一条新的消息,请查看。", incoming_message) + def reply_with_ai_markdown(): + button_list, markdown_content = self.generate_button_markdown_content(context, reply) + self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "",[incoming_message.sender_staff_id]) + + if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]: + if isgroup: + reply_with_ai_markdown() + reply_with_at_text() + else: + reply_with_ai_markdown() + else: + # 暂不支持其它类型消息回复 + reply_with_text() + else: + self.reply_text(reply.content, incoming_message) + + + def generate_button_markdown_content(self, context, reply): + image_url = context.kwargs.get("image_url") + promptEn = context.kwargs.get("promptEn") + reply_text = reply.content + button_list = [] + markdown_content = f""" +{reply.content} + """ + if image_url is not None and promptEn is not None: + button_list = [ + {"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"} + ] + markdown_content = f""" +{promptEn} + +!["图片"]({image_url}) + +{reply_text} + + """ + logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , 
markdown_content={markdown_content}") + + return button_list, markdown_content diff --git a/channel/dingtalk/dingtalk_message.py b/channel/dingtalk/dingtalk_message.py new file mode 100644 index 000000000..c069a1bb6 --- /dev/null +++ b/channel/dingtalk/dingtalk_message.py @@ -0,0 +1,84 @@ +import os + +import requests +from dingtalk_stream import ChatbotMessage + +from bridge.context import ContextType +from channel.chat_message import ChatMessage +# -*- coding=utf-8 -*- +from common.log import logger +from common.tmp_dir import TmpDir + + +class DingTalkMessage(ChatMessage): + def __init__(self, event: ChatbotMessage, image_download_handler): + super().__init__(event) + self.image_download_handler = image_download_handler + self.msg_id = event.message_id + self.message_type = event.message_type + self.incoming_message = event + self.sender_staff_id = event.sender_staff_id + self.other_user_id = event.conversation_id + self.create_time = event.create_at + self.image_content = event.image_content + self.rich_text_content = event.rich_text_content + if event.conversation_type == "1": + self.is_group = False + else: + self.is_group = True + + if self.message_type == "text": + self.ctype = ContextType.TEXT + + self.content = event.text.content.strip() + elif self.message_type == "audio": + # 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理 + self.content = event.extensions['content']['recognition'].strip() + self.ctype = ContextType.TEXT + elif (self.message_type == 'picture') or (self.message_type == 'richText'): + self.ctype = ContextType.IMAGE + # 钉钉图片类型或富文本类型消息处理 + image_list = event.get_image_list() + if len(image_list) > 0: + download_code = image_list[0] + download_url = image_download_handler.get_image_download_url(download_code) + self.content = download_image_file(download_url, TmpDir().path()) + else: + logger.debug(f"[Dingtalk] messageType :{self.message_type} , imageList isEmpty") + + if self.is_group: + self.from_user_id = event.conversation_id + self.actual_user_id = 
event.sender_id + self.is_at = True + else: + self.from_user_id = event.sender_id + self.actual_user_id = event.sender_id + self.to_user_id = event.chatbot_user_id + self.other_user_nickname = event.conversation_title + + +def download_image_file(image_url, temp_dir): + headers = { + 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36' + } + # 设置代理 + # self.proxies + # , proxies=self.proxies + response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5) + if response.status_code == 200: + + # 生成文件名 + file_name = image_url.split("/")[-1].split("?")[0] + + # 检查临时目录是否存在,如果不存在则创建 + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + + # 将文件保存到临时目录 + file_path = os.path.join(temp_dir, file_name) + with open(file_path, 'wb') as file: + file.write(response.content) + return file_path + else: + logger.info(f"[Dingtalk] Failed to download image file, {response.content}") + return None diff --git a/channel/feishu/feishu_channel.py b/channel/feishu/feishu_channel.py new file mode 100644 index 000000000..37837d4e7 --- /dev/null +++ b/channel/feishu/feishu_channel.py @@ -0,0 +1,254 @@ +""" +飞书通道接入 + +@author Saboteur7 +@Date 2023/11/19 +""" + +# -*- coding=utf-8 -*- +import uuid + +import requests +import web +from channel.feishu.feishu_message import FeishuMessage +from bridge.context import Context +from bridge.reply import Reply, ReplyType +from common.log import logger +from common.singleton import singleton +from config import conf +from common.expired_dict import ExpiredDict +from bridge.context import ContextType +from channel.chat_channel import ChatChannel, check_prefix +from common import utils +import json +import os + +URL_VERIFICATION = "url_verification" + + +@singleton +class FeiShuChanel(ChatChannel): + feishu_app_id = conf().get('feishu_app_id') + feishu_app_secret = conf().get('feishu_app_secret') + feishu_token = conf().get('feishu_token') + + def 
__init__(self): + super().__init__() + # 历史消息id暂存,用于幂等控制 + self.receivedMsgs = ExpiredDict(60 * 60 * 7.1) + logger.info("[FeiShu] app_id={}, app_secret={} verification_token={}".format( + self.feishu_app_id, self.feishu_app_secret, self.feishu_token)) + # 无需群校验和前缀 + conf()["group_name_white_list"] = ["ALL_GROUP"] + conf()["single_chat_prefix"] = [""] + + def startup(self): + urls = ( + '/', 'channel.feishu.feishu_channel.FeishuController' + ) + app = web.application(urls, globals(), autoreload=False) + port = conf().get("feishu_port", 9891) + web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) + + def send(self, reply: Reply, context: Context): + msg = context.get("msg") + is_group = context["isgroup"] + if msg: + access_token = msg.access_token + else: + access_token = self.fetch_access_token() + headers = { + "Authorization": "Bearer " + access_token, + "Content-Type": "application/json", + } + msg_type = "text" + logger.info(f"[FeiShu] start send reply message, type={context.type}, content={reply.content}") + reply_content = reply.content + content_key = "text" + if reply.type == ReplyType.IMAGE_URL: + # 图片上传 + reply_content = self._upload_image_url(reply.content, access_token) + if not reply_content: + logger.warning("[FeiShu] upload file failed") + return + msg_type = "image" + content_key = "image_key" + if is_group: + # 群聊中直接回复 + url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply" + data = { + "msg_type": msg_type, + "content": json.dumps({content_key: reply_content}) + } + res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10)) + else: + url = "https://open.feishu.cn/open-apis/im/v1/messages" + params = {"receive_id_type": context.get("receive_id_type") or "open_id"} + data = { + "receive_id": context.get("receiver"), + "msg_type": msg_type, + "content": json.dumps({content_key: reply_content}) + } + res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10)) + res = res.json() + 
if res.get("code") == 0: + logger.info(f"[FeiShu] send message success") + else: + logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}") + + + def fetch_access_token(self) -> str: + url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/" + headers = { + "Content-Type": "application/json" + } + req_body = { + "app_id": self.feishu_app_id, + "app_secret": self.feishu_app_secret + } + data = bytes(json.dumps(req_body), encoding='utf8') + response = requests.post(url=url, data=data, headers=headers) + if response.status_code == 200: + res = response.json() + if res.get("code") != 0: + logger.error(f"[FeiShu] get tenant_access_token error, code={res.get('code')}, msg={res.get('msg')}") + return "" + else: + return res.get("tenant_access_token") + else: + logger.error(f"[FeiShu] fetch token error, res={response}") + + + def _upload_image_url(self, img_url, access_token): + logger.debug(f"[WX] start download image, img_url={img_url}") + response = requests.get(img_url) + suffix = utils.get_path_suffix(img_url) + temp_name = str(uuid.uuid4()) + "." + suffix + if response.status_code == 200: + # 将图片内容保存为临时文件 + with open(temp_name, "wb") as file: + file.write(response.content) + + # upload + upload_url = "https://open.feishu.cn/open-apis/im/v1/images" + data = { + 'image_type': 'message' + } + headers = { + 'Authorization': f'Bearer {access_token}', + } + with open(temp_name, "rb") as file: + upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers) + logger.info(f"[FeiShu] upload file, res={upload_response.content}") + os.remove(temp_name) + return upload_response.json().get("data").get("image_key") + + + +class FeishuController: + # 类常量 + FAILED_MSG = '{"success": false}' + SUCCESS_MSG = '{"success": true}' + MESSAGE_RECEIVE_TYPE = "im.message.receive_v1" + + def GET(self): + return "Feishu service start success!" 
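The send() method above picks one of two Feishu endpoints: a group message is answered in-thread via the per-message `/reply` API, while a p2p message is created against the receiver's `open_id`, and in both cases the payload is JSON-encoded under `text` or `image_key`. The selection can be sketched as a pure helper; `build_feishu_request` is a hypothetical name introduced here for illustration, not part of the project:

```python
import json

# Hypothetical helper mirroring the endpoint/payload choice in FeiShuChanel.send:
# group messages reply in-thread, p2p messages are created with receive_id_type.
def build_feishu_request(is_group, msg_id, receiver, content,
                         msg_type="text", content_key="text",
                         receive_id_type="open_id"):
    body = {"msg_type": msg_type, "content": json.dumps({content_key: content})}
    if is_group:
        url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg_id}/reply"
        params = {}
    else:
        url = "https://open.feishu.cn/open-apis/im/v1/messages"
        params = {"receive_id_type": receive_id_type}
        body["receive_id"] = receiver
    return url, params, body
```

Keeping this selection in a pure function makes the endpoint choice easy to unit-test without touching the network.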
+ + def POST(self): + try: + channel = FeiShuChanel() + + request = json.loads(web.data().decode("utf-8")) + logger.debug(f"[FeiShu] receive request: {request}") + + # 1.事件订阅回调验证 + if request.get("type") == URL_VERIFICATION: + verify_res = {"challenge": request.get("challenge")} + return json.dumps(verify_res) + + # 2.消息接收处理 + # token 校验 + header = request.get("header") + if not header or header.get("token") != channel.feishu_token: + return self.FAILED_MSG + + # 处理消息事件 + event = request.get("event") + if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event: + if not event.get("message") or not event.get("sender"): + logger.warning(f"[FeiShu] invalid message, msg={request}") + return self.FAILED_MSG + msg = event.get("message") + + # 幂等判断 + if channel.receivedMsgs.get(msg.get("message_id")): + logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}") + return self.SUCCESS_MSG + channel.receivedMsgs[msg.get("message_id")] = True + + is_group = False + chat_type = msg.get("chat_type") + if chat_type == "group": + if not msg.get("mentions") and msg.get("message_type") == "text": + # 群聊中未@不响应 + return self.SUCCESS_MSG + if msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get("message_type") == "text": + # 不是@机器人,不响应 + return self.SUCCESS_MSG + # 群聊 + is_group = True + receive_id_type = "chat_id" + elif chat_type == "p2p": + receive_id_type = "open_id" + else: + logger.warning("[FeiShu] message ignore") + return self.SUCCESS_MSG + # 构造飞书消息对象 + feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token()) + if not feishu_msg: + return self.SUCCESS_MSG + + context = self._compose_context( + feishu_msg.ctype, + feishu_msg.content, + isgroup=is_group, + msg=feishu_msg, + receive_id_type=receive_id_type, + no_need_at=True + ) + if context: + channel.produce(context) + logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}") + return self.SUCCESS_MSG + + except 
Exception as e: + logger.error(e) + return self.FAILED_MSG + + def _compose_context(self, ctype: ContextType, content, **kwargs): + context = Context(ctype, content) + context.kwargs = kwargs + if "origin_ctype" not in context: + context["origin_ctype"] = ctype + + cmsg = context["msg"] + context["session_id"] = cmsg.from_user_id + context["receiver"] = cmsg.other_user_id + + if ctype == ContextType.TEXT: + # 1.文本请求 + # 图片生成处理 + img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) + if img_match_prefix: + content = content.replace(img_match_prefix, "", 1) + context.type = ContextType.IMAGE_CREATE + else: + context.type = ContextType.TEXT + context.content = content.strip() + + elif context.type == ContextType.VOICE: + # 2.语音请求 + if "desire_rtype" not in context and conf().get("voice_reply_voice"): + context["desire_rtype"] = ReplyType.VOICE + + return context diff --git a/channel/feishu/feishu_message.py b/channel/feishu/feishu_message.py new file mode 100644 index 000000000..e2054c127 --- /dev/null +++ b/channel/feishu/feishu_message.py @@ -0,0 +1,63 @@ +from bridge.context import ContextType +from channel.chat_message import ChatMessage +import json +import requests +from common.log import logger +from common.tmp_dir import TmpDir +from common import utils + + +class FeishuMessage(ChatMessage): + def __init__(self, event: dict, is_group=False, access_token=None): + super().__init__(event) + msg = event.get("message") + sender = event.get("sender") + self.access_token = access_token + self.msg_id = msg.get("message_id") + self.create_time = msg.get("create_time") + self.is_group = is_group + msg_type = msg.get("message_type") + + if msg_type == "text": + self.ctype = ContextType.TEXT + content = json.loads(msg.get('content')) + self.content = content.get("text").strip() + elif msg_type == "file": + self.ctype = ContextType.FILE + content = json.loads(msg.get("content")) + file_key = content.get("file_key") + file_name = 
content.get("file_name") + + self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name) + + def _download_file(): + # 如果响应状态码是200,则将响应内容写入本地文件 + url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}" + headers = { + "Authorization": "Bearer " + access_token, + } + params = { + "type": "file" + } + response = requests.get(url=url, headers=headers, params=params) + if response.status_code == 200: + with open(self.content, "wb") as f: + f.write(response.content) + else: + logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}") + self._prepare_fn = _download_file + else: + raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type)) + + self.from_user_id = sender.get("sender_id").get("open_id") + self.to_user_id = event.get("app_id") + if is_group: + # 群聊 + self.other_user_id = msg.get("chat_id") + self.actual_user_id = self.from_user_id + self.content = self.content.replace("@_user_1", "").strip() + self.actual_user_nickname = "" + else: + # 私聊 + self.other_user_id = self.from_user_id + self.actual_user_id = self.from_user_id diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py index 9a413dcff..9b64eb475 100644 --- a/channel/terminal/terminal_channel.py +++ b/channel/terminal/terminal_channel.py @@ -78,6 +78,7 @@ def startup(self): prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)) + context["isgroup"] = False if context: self.produce(context) else: diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index 0989a8580..8b4455481 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -9,17 +9,18 @@ import os import threading import time - import requests from bridge.context import * from bridge.reply import * from channel.chat_channel import ChatChannel +from channel 
import chat_channel from channel.wechat.wechat_message import * from common.expired_dict import ExpiredDict from common.log import logger from common.singleton import singleton from common.time_check import time_checker +from common.utils import convert_webp_to_png from config import conf, get_appdata_dir from lib import itchat from lib.itchat.content import * @@ -95,7 +96,7 @@ def qrCallback(uuid, status, qrcode): print(qr_api4) print(qr_api2) print(qr_api1) - + _send_qr_code([qr_api3, qr_api4, qr_api2, qr_api1]) qr = qrcode.QRCode(border=1) qr.add_data(url) qr.make(fit=True) @@ -108,24 +109,47 @@ class WechatChannel(ChatChannel): def __init__(self): super().__init__() - self.receivedMsgs = ExpiredDict(60 * 60) + self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600)) + self.auto_login_times = 0 def startup(self): - itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 - # login by scan QRCode - hotReload = conf().get("hot_reload", False) - status_path = os.path.join(get_appdata_dir(), "itchat.pkl") - itchat.auto_login( - enableCmdQR=2, - hotReload=hotReload, - statusStorageDir=status_path, - qrCallback=qrCallback, - ) - self.user_id = itchat.instance.storageClass.userName - self.name = itchat.instance.storageClass.nickName - logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) - # start message listener - itchat.run() + try: + itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 + # login by scan QRCode + hotReload = conf().get("hot_reload", False) + status_path = os.path.join(get_appdata_dir(), "itchat.pkl") + itchat.auto_login( + enableCmdQR=2, + hotReload=hotReload, + statusStorageDir=status_path, + qrCallback=qrCallback, + exitCallback=self.exitCallback, + loginCallback=self.loginCallback + ) + self.user_id = itchat.instance.storageClass.userName + self.name = itchat.instance.storageClass.nickName + logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) + # start 
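The `WechatChannel` change above makes the `receivedMsgs` de-duplication window configurable via `expires_in_seconds` instead of a hard-coded hour. A minimal sketch of the TTL-dict idea behind `ExpiredDict` (the real `common/expired_dict.py` may differ in detail):

```python
import time

class ExpiredDict(dict):
    """Sketch of a dict whose entries expire after `ttl` seconds,
    as used for message-id de-duplication in the channel."""

    def __init__(self, ttl: float):
        super().__init__()
        self.ttl = ttl

    def __setitem__(self, key, value):
        # store the expiry deadline alongside the value
        super().__setitem__(key, (time.monotonic() + self.ttl, value))

    def __contains__(self, key):
        if not super().__contains__(key):
            return False
        expires_at, _ = super().__getitem__(key)
        if time.monotonic() > expires_at:
            super().__delitem__(key)  # lazily evict expired entries
            return False
        return True

received = ExpiredDict(ttl=3600)
received["msg-1"] = True
```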
message listener + itchat.run() + except Exception as e: + logger.exception(e) + + def exitCallback(self): + try: + from common.linkai_client import chat_client + if chat_client.client_id and conf().get("use_linkai"): + _send_logout() + time.sleep(2) + self.auto_login_times += 1 + if self.auto_login_times < 100: + chat_channel.handler_pool._shutdown = False + self.startup() + except Exception as e: + pass + + def loginCallback(self): + logger.debug("Login success") + _send_login_success() # handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复 # Context包含了消息的所有信息,包括以下属性 @@ -138,7 +162,6 @@ def startup(self): # msg: ChatMessage消息对象 # origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则 # desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复 - @time_checker @_check def handle_single(self, cmsg: ChatMessage): @@ -170,7 +193,7 @@ def handle_group(self, cmsg: ChatMessage): logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) elif cmsg.ctype == ContextType.IMAGE: logger.debug("[WX]receive image for group msg: {}".format(cmsg.content)) - elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND]: + elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND, ContextType.EXIT_GROUP]: logger.debug("[WX]receive note msg: {}".format(cmsg.content)) elif cmsg.ctype == ContextType.TEXT: # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) @@ -179,7 +202,7 @@ def handle_group(self, cmsg: ChatMessage): logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}") else: logger.debug("[WX]receive group msg: {}".format(cmsg.content)) - context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) + context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg, no_need_at=conf().get("no_need_at", False)) if context: self.produce(context) @@ -206,6 +229,12 @@ def send(self, reply: Reply, 
context: Context): image_storage.write(block) logger.info(f"[WX] download image success, size={size}, img_url={img_url}") image_storage.seek(0) + if ".webp" in img_url: + try: + image_storage = convert_webp_to_png(image_storage) + except Exception as e: + logger.error(f"Failed to convert image: {e}") + return itchat.send_image(image_storage, toUserName=receiver) logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver)) elif reply.type == ReplyType.IMAGE: # 从文件读取图片 @@ -234,3 +263,30 @@ def send(self, reply: Reply, context: Context): video_storage.seek(0) itchat.send_video(video_storage, toUserName=receiver) logger.info("[WX] sendVideo url={}, receiver={}".format(video_url, receiver)) + +def _send_login_success(): + try: + from common.linkai_client import chat_client + if chat_client.client_id: + chat_client.send_login_success() + except Exception as e: + pass + + +def _send_logout(): + try: + from common.linkai_client import chat_client + if chat_client.client_id: + chat_client.send_logout() + except Exception as e: + pass + + +def _send_qr_code(qrcode_list: list): + try: + from common.linkai_client import chat_client + if chat_client.client_id: + chat_client.send_qrcode(qrcode_list) + except Exception as e: + pass + diff --git a/channel/wechat/wechat_message.py b/channel/wechat/wechat_message.py index e1020186d..e7109d658 100644 --- a/channel/wechat/wechat_message.py +++ b/channel/wechat/wechat_message.py @@ -14,6 +14,11 @@ def __init__(self, itchat_msg, is_group=False): self.create_time = itchat_msg["CreateTime"] self.is_group = is_group + notes_join_group = ["加入群聊", "加入了群聊", "invited", "joined"] # 可通过添加对应语言的加入群聊通知中的关键词适配更多 + notes_bot_join_group = ["邀请你", "invited you", "You've joined", "你通过扫描"] + notes_exit_group = ["移出了群聊", "removed"] # 可通过添加对应语言的踢出群聊通知中的关键词适配更多 + notes_patpat = ["拍了拍我", "tickled my", "tickled me"] # 可通过添加对应语言的拍一拍通知中的关键词适配更多 + if itchat_msg["Type"] == TEXT: self.ctype = ContextType.TEXT self.content = itchat_msg["Text"] @@ 
-26,22 +31,42 @@ def __init__(self, itchat_msg, is_group=False): self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 self._prepare_fn = lambda: itchat_msg.download(self.content) elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000: - if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]): - self.ctype = ContextType.JOIN_GROUP - self.content = itchat_msg["Content"] + if is_group: + if any(note_bot_join_group in itchat_msg["Content"] for note_bot_join_group in notes_bot_join_group): # 邀请机器人加入群聊 + logger.warn("机器人加入群聊消息,不处理~") + pass + elif any(note_join_group in itchat_msg["Content"] for note_join_group in notes_join_group): # 若有任何在notes_join_group列表中的字符串出现在NOTE中 # 这里只能得到nickname, actual_user_id还是机器人的id - if "加入了群聊" in itchat_msg["Content"]: - self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1] - elif "加入群聊" in itchat_msg["Content"]: + if "加入群聊" not in itchat_msg["Content"]: + self.ctype = ContextType.JOIN_GROUP + self.content = itchat_msg["Content"] + if "invited" in itchat_msg["Content"]: # 匹配英文信息 + self.actual_user_nickname = re.findall(r'invited\s+(.+?)\s+to\s+the\s+group\s+chat', itchat_msg["Content"])[0] + elif "joined" in itchat_msg["Content"]: # 匹配通过二维码加入的英文信息 + self.actual_user_nickname = re.findall(r'"(.*?)" joined the group chat via the QR Code shared by', itchat_msg["Content"])[0] + elif "加入了群聊" in itchat_msg["Content"]: + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1] + elif "加入群聊" in itchat_msg["Content"]: + self.ctype = ContextType.JOIN_GROUP + self.content = itchat_msg["Content"] + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] + + elif any(note_exit_group in itchat_msg["Content"] for note_exit_group in notes_exit_group): # 若有任何在notes_exit_group列表中的字符串出现在NOTE中 + self.ctype = ContextType.EXIT_GROUP + self.content = itchat_msg["Content"] self.actual_user_nickname = re.findall(r"\"(.*?)\"", 
itchat_msg["Content"])[0] + elif "你已添加了" in itchat_msg["Content"]: #通过好友请求 self.ctype = ContextType.ACCEPT_FRIEND self.content = itchat_msg["Content"] - elif "拍了拍我" in itchat_msg["Content"]: + elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中: self.ctype = ContextType.PATPAT self.content = itchat_msg["Content"] if is_group: - self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] + if "拍了拍我" in itchat_msg["Content"]: # 识别中文 + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] + elif ("tickled my" in itchat_msg["Content"] or "tickled me" in itchat_msg["Content"]): + self.actual_user_nickname = re.findall(r'^(.*?)(?:tickled my|tickled me)', itchat_msg["Content"])[0] else: raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"]) elif itchat_msg["Type"] == ATTACHMENT: @@ -90,5 +115,5 @@ def __init__(self, itchat_msg, is_group=False): if self.is_group: self.is_at = itchat_msg["IsAt"] self.actual_user_id = itchat_msg["ActualUserName"] - if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT]: + if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.EXIT_GROUP]: self.actual_user_nickname = itchat_msg["ActualNickName"] diff --git a/channel/wechatcom/wechatcomapp_channel.py b/channel/wechatcom/wechatcomapp_channel.py index 1a0859690..5ed329681 100644 --- a/channel/wechatcom/wechatcomapp_channel.py +++ b/channel/wechatcom/wechatcomapp_channel.py @@ -17,7 +17,7 @@ from channel.wechatcom.wechatcomapp_message import WechatComAppMessage from common.log import logger from common.singleton import singleton -from common.utils import compress_imgfile, fsize, split_string_by_utf8_length +from common.utils import compress_imgfile, fsize, split_string_by_utf8_length, convert_webp_to_png from config import conf, subscribe_msg from voice.audio_convert import any_to_amr, split_audio @@ -44,7 +44,7 @@ def 
__init__(self): def startup(self): # start message listener - urls = ("/wxcomapp", "channel.wechatcom.wechatcomapp_channel.Query") + urls = ("/wxcomapp/?", "channel.wechatcom.wechatcomapp_channel.Query") app = web.application(urls, globals(), autoreload=False) port = conf().get("wechatcomapp_port", 9898) web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) @@ -99,6 +99,12 @@ def send(self, reply: Reply, context: Context): image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1) logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage))) image_storage.seek(0) + if ".webp" in img_url: + try: + image_storage = convert_webp_to_png(image_storage) + except Exception as e: + logger.error(f"Failed to convert image: {e}") + return try: response = self.client.media.upload("image", image_storage) logger.debug("[wechatcom] upload image response: {}".format(response)) @@ -156,11 +162,12 @@ def POST(self): logger.debug("[wechatcom] receive message: {}, msg= {}".format(message, msg)) if msg.type == "event": if msg.event == "subscribe": - reply_content = subscribe_msg() - if reply_content: - reply = create_reply(reply_content, msg).render() - res = channel.crypto.encrypt_message(reply, nonce, timestamp) - return res + pass + # reply_content = subscribe_msg() + # if reply_content: + # reply = create_reply(reply_content, msg).render() + # res = channel.crypto.encrypt_message(reply, nonce, timestamp) + # return res else: try: wechatcom_msg = WechatComAppMessage(msg, client=channel.client) diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index 2ae88228f..fa2397681 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -140,6 +140,42 @@ def send(self, reply: Reply, context: Context): media_id = response["media_id"] logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id)) self.cache_dict[receiver].append(("image", media_id)) + elif 
reply.type == ReplyType.VIDEO_URL: # 从网络下载视频 + video_url = reply.content + video_res = requests.get(video_url, stream=True) + video_storage = io.BytesIO() + for block in video_res.iter_content(1024): + video_storage.write(block) + video_storage.seek(0) + video_type = 'mp4' + filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type + content_type = "video/" + video_type + try: + response = self.client.material.add("video", (filename, video_storage, content_type)) + logger.debug("[wechatmp] upload video response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload video failed: {}".format(e)) + return + media_id = response["media_id"] + logger.info("[wechatmp] video uploaded, receiver {}, media_id {}".format(receiver, media_id)) + self.cache_dict[receiver].append(("video", media_id)) + + elif reply.type == ReplyType.VIDEO: # 从文件读取视频 + video_storage = reply.content + video_storage.seek(0) + video_type = 'mp4' + filename = receiver + "-" + str(context["msg"].msg_id) + "." 
+ video_type + content_type = "video/" + video_type + try: + response = self.client.material.add("video", (filename, video_storage, content_type)) + logger.debug("[wechatmp] upload video response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload video failed: {}".format(e)) + return + media_id = response["media_id"] + logger.info("[wechatmp] video uploaded, receiver {}, media_id {}".format(receiver, media_id)) + self.cache_dict[receiver].append(("video", media_id)) + else: if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR: reply_text = reply.content @@ -222,6 +258,38 @@ def send(self, reply: Reply, context: Context): return self.client.message.send_image(receiver, response["media_id"]) logger.info("[wechatmp] Do send image to {}".format(receiver)) + elif reply.type == ReplyType.VIDEO_URL: # 从网络下载视频 + video_url = reply.content + video_res = requests.get(video_url, stream=True) + video_storage = io.BytesIO() + for block in video_res.iter_content(1024): + video_storage.write(block) + video_storage.seek(0) + video_type = 'mp4' + filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type + content_type = "video/" + video_type + try: + response = self.client.media.upload("video", (filename, video_storage, content_type)) + logger.debug("[wechatmp] upload video response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload video failed: {}".format(e)) + return + self.client.message.send_video(receiver, response["media_id"]) + logger.info("[wechatmp] Do send video to {}".format(receiver)) + elif reply.type == ReplyType.VIDEO: # 从文件读取视频 + video_storage = reply.content + video_storage.seek(0) + video_type = 'mp4' + filename = receiver + "-" + str(context["msg"].msg_id) + "." 
+ video_type + content_type = "video/" + video_type + try: + response = self.client.media.upload("video", (filename, video_storage, content_type)) + logger.debug("[wechatmp] upload video response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload video failed: {}".format(e)) + return + self.client.message.send_video(receiver, response["media_id"]) + logger.info("[wechatmp] Do send video to {}".format(receiver)) return def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 diff --git a/channel/wework/wework_channel.py b/channel/wework/wework_channel.py index fb7784373..102026105 100644 --- a/channel/wework/wework_channel.py +++ b/channel/wework/wework_channel.py @@ -120,7 +120,7 @@ def wrapper(self, cmsg: ChatMessage): @wework.msg_register( - [ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_VOICE_MSG]) + [ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_LINK_CARD_MSG,ntwork.MT_RECV_FILE_MSG, ntwork.MT_RECV_VOICE_MSG]) def all_msg_handler(wework_instance: ntwork.WeWork, message): logger.debug(f"收到消息: {message}") if 'data' in message: diff --git a/channel/wework/wework_message.py b/channel/wework/wework_message.py index e95dfb194..0d9e96ef7 100644 --- a/channel/wework/wework_message.py +++ b/channel/wework/wework_message.py @@ -64,7 +64,10 @@ def cdn_download(wework, message, file_name): } result = wework._WeWork__send_sync(send_type.MT_WXCDN_DOWNLOAD_MSG, data) # 直接用wx_cdn_download的接口内部实现来调用 elif "file_id" in data["cdn"].keys(): - file_type = 2 + if message["type"] == 11042: + file_type = 2 + elif message["type"] == 11045: + file_type = 5 file_id = data["cdn"]["file_id"] result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path) else: @@ -128,6 +131,18 @@ def __init__(self, wework_msg, wework, is_group=False): self.ctype = ContextType.IMAGE self.content = os.path.join(current_dir, "tmp", file_name) self._prepare_fn = lambda: 
cdn_download(wework, wework_msg, file_name) + elif wework_msg["type"] == 11045: # 文件消息 + logger.debug("[WeWork] receive file msg") + logger.debug(wework_msg) + file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + file_name = file_name + wework_msg['data']['cdn']['file_name'] + current_dir = os.getcwd() + self.ctype = ContextType.FILE + self.content = os.path.join(current_dir, "tmp", file_name) + self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name) + elif wework_msg["type"] == 11047: # 链接消息 + self.ctype = ContextType.SHARING + self.content = wework_msg['data']['url'] elif wework_msg["type"] == 11072: # 新成员入群通知 self.ctype = ContextType.JOIN_GROUP member_list = wework_msg['data']['member_list'] @@ -179,6 +194,7 @@ def __init__(self, wework_msg, wework, is_group=False): if conversation_id: room_info = get_room_info(wework=wework, conversation_id=conversation_id) self.other_user_nickname = room_info.get('nickname', None) if room_info else None + self.from_user_nickname = room_info.get('nickname', None) if room_info else None at_list = data.get('at_list', []) tmp_list = [] for at in at_list: diff --git a/common/const.py b/common/const.py index 6c790f794..e2e6a00e1 100644 --- a/common/const.py +++ b/common/const.py @@ -1,18 +1,74 @@ # bot_type OPEN_AI = "openAI" CHATGPT = "chatGPT" -BAIDU = "baidu" +BAIDU = "baidu" # 百度文心一言模型 XUNFEI = "xunfei" CHATGPTONAZURE = "chatGPTOnAzure" LINKAI = "linkai" -CLAUDEAI = "claude" +CLAUDEAI = "claude" # 使用cookie的历史模型 +CLAUDEAPI = "claudeAPI" # 通过Claude api调用模型 +QWEN = "qwen" # 旧版通义模型 +QWEN_DASHSCOPE = "dashscope" # 通义新版sdk和api key + + +GEMINI = "gemini" # gemini-1.0-pro +ZHIPU_AI = "glm-4" +MOONSHOT = "moonshot" +MiniMax = "minimax" + # model -GPT4 = "gpt-4" -GPT4_TURBO_PREVIEW = "gpt-4-1106-preview" +CLAUDE3 = "claude-3-opus-20240229" +GPT35 = "gpt-3.5-turbo" +GPT35_0125 = "gpt-3.5-turbo-0125" +GPT35_1106 = "gpt-3.5-turbo-1106" + +GPT_4o = "gpt-4o" +GPT_4O_0806 = "gpt-4o-2024-08-06" +GPT4_TURBO = "gpt-4-turbo" +GPT4_TURBO_PREVIEW =
"gpt-4-turbo-preview" +GPT4_TURBO_04_09 = "gpt-4-turbo-2024-04-09" +GPT4_TURBO_01_25 = "gpt-4-0125-preview" +GPT4_TURBO_11_06 = "gpt-4-1106-preview" GPT4_VISION_PREVIEW = "gpt-4-vision-preview" + +GPT4 = "gpt-4" +GPT_4o_MINI = "gpt-4o-mini" +GPT4_32k = "gpt-4-32k" +GPT4_06_13 = "gpt-4-0613" +GPT4_32k_06_13 = "gpt-4-32k-0613" + WHISPER_1 = "whisper-1" TTS_1 = "tts-1" TTS_1_HD = "tts-1-hd" -MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW] +WEN_XIN = "wenxin" +WEN_XIN_4 = "wenxin-4" + +QWEN_TURBO = "qwen-turbo" +QWEN_PLUS = "qwen-plus" +QWEN_MAX = "qwen-max" + +LINKAI_35 = "linkai-3.5" +LINKAI_4_TURBO = "linkai-4-turbo" +LINKAI_4o = "linkai-4o" + +GEMINI_PRO = "gemini-1.0-pro" +GEMINI_15_flash = "gemini-1.5-flash" +GEMINI_15_PRO = "gemini-1.5-pro" + +MODEL_LIST = [ + GPT35, GPT35_0125, GPT35_1106, "gpt-3.5-turbo-16k", + GPT_4o, GPT_4O_0806, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13, + WEN_XIN, WEN_XIN_4, + XUNFEI, ZHIPU_AI, MOONSHOT, MiniMax, + GEMINI, GEMINI_PRO, GEMINI_15_flash, GEMINI_15_PRO, + "claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3-opus-20240229", "claude-3.5-sonnet", + "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k", + QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX, + LINKAI_35, LINKAI_4_TURBO, LINKAI_4o + ] + +# channel +FEISHU = "feishu" +DINGTALK = "dingtalk" diff --git a/common/linkai_client.py b/common/linkai_client.py new file mode 100644 index 000000000..0a922f34d --- /dev/null +++ b/common/linkai_client.py @@ -0,0 +1,105 @@ +from bridge.context import Context, ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger +from linkai import LinkAIClient, PushMsg +from config import conf, pconf, plugin_config, available_setting +from plugins import PluginManager +import time + + +chat_client: LinkAIClient + + +class 
ChatClient(LinkAIClient): + def __init__(self, api_key, host, channel): + super().__init__(api_key, host) + self.channel = channel + self.client_type = channel.channel_type + + def on_message(self, push_msg: PushMsg): + session_id = push_msg.session_id + msg_content = push_msg.msg_content + logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}") + context = Context() + context.type = ContextType.TEXT + context["receiver"] = session_id + context["isgroup"] = push_msg.is_group + self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) + + def on_config(self, config: dict): + if not self.client_id: + return + logger.info(f"[LinkAI] 从客户端管理加载远程配置: {config}") + if config.get("enabled") != "Y": + return + + local_config = conf() + for key in config.keys(): + if key in available_setting and config.get(key) is not None: + local_config[key] = config.get(key) + # 语音配置 + reply_voice_mode = config.get("reply_voice_mode") + if reply_voice_mode: + if reply_voice_mode == "voice_reply_voice": + local_config["voice_reply_voice"] = True + elif reply_voice_mode == "always_reply_voice": + local_config["always_reply_voice"] = True + + if config.get("admin_password"): + if not plugin_config.get("Godcmd"): + plugin_config["Godcmd"] = {"password": config.get("admin_password"), "admin_users": []} + else: + plugin_config["Godcmd"]["password"] = config.get("admin_password") + PluginManager().instances["GODCMD"].reload() + + if config.get("group_app_map") and pconf("linkai"): + local_group_map = {} + for mapping in config.get("group_app_map"): + local_group_map[mapping.get("group_name")] = mapping.get("app_code") + pconf("linkai")["group_app_map"] = local_group_map + PluginManager().instances["LINKAI"].reload() + + if config.get("text_to_image") and config.get("text_to_image") == "midjourney" and pconf("linkai"): + if pconf("linkai")["midjourney"]: + pconf("linkai")["midjourney"]["enabled"] = True + 
pconf("linkai")["midjourney"]["use_image_create_prefix"] = True + elif config.get("text_to_image") and config.get("text_to_image") in ["dall-e-2", "dall-e-3"]: + if pconf("linkai")["midjourney"]: + pconf("linkai")["midjourney"]["use_image_create_prefix"] = False + + +def start(channel): + global chat_client + chat_client = ChatClient(api_key=conf().get("linkai_api_key"), host="", channel=channel) + chat_client.config = _build_config() + chat_client.start() + time.sleep(1.5) + if chat_client.client_id: + logger.info("[LinkAI] 可前往控制台进行线上登录和配置:https://link-ai.tech/console/clients") + + +def _build_config(): + local_conf = conf() + config = { + "linkai_app_code": local_conf.get("linkai_app_code"), + "single_chat_prefix": local_conf.get("single_chat_prefix"), + "single_chat_reply_prefix": local_conf.get("single_chat_reply_prefix"), + "single_chat_reply_suffix": local_conf.get("single_chat_reply_suffix"), + "group_chat_prefix": local_conf.get("group_chat_prefix"), + "group_chat_reply_prefix": local_conf.get("group_chat_reply_prefix"), + "group_chat_reply_suffix": local_conf.get("group_chat_reply_suffix"), + "group_name_white_list": local_conf.get("group_name_white_list"), + "nick_name_black_list": local_conf.get("nick_name_black_list"), + "speech_recognition": "Y" if local_conf.get("speech_recognition") else "N", + "text_to_image": local_conf.get("text_to_image"), + "image_create_prefix": local_conf.get("image_create_prefix") + } + if local_conf.get("always_reply_voice"): + config["reply_voice_mode"] = "always_reply_voice" + elif local_conf.get("voice_reply_voice"): + config["reply_voice_mode"] = "voice_reply_voice" + if pconf("linkai"): + config["group_app_map"] = pconf("linkai").get("group_app_map") + if plugin_config.get("Godcmd"): + config["admin_password"] = plugin_config.get("Godcmd").get("password") + return config diff --git a/common/memory.py b/common/memory.py new file mode 100644 index 000000000..026bed2c8 --- /dev/null +++ b/common/memory.py @@ -0,0 +1,3 @@ 
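`on_config` above only copies remote keys that also exist in `available_setting`, so an unexpected key pushed from the console cannot pollute the local config. The guard can be sketched in isolation, with a stand-in `available_setting` (the real one is the settings dict exported by config.py):

```python
# Stand-in subset of the settings schema; assumption for this sketch.
available_setting = {"single_chat_prefix": [""], "temperature": 0.7}

def merge_remote_config(local_config: dict, remote: dict) -> dict:
    """Copy whitelisted, non-None remote keys into local_config,
    mirroring the guard in ChatClient.on_config."""
    if remote.get("enabled") != "Y":
        return local_config  # remote management disabled: keep local values
    for key, value in remote.items():
        if key in available_setting and value is not None:
            local_config[key] = value
    return local_config
```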
+from common.expired_dict import ExpiredDict + +USER_IMAGE_CACHE = ExpiredDict(60 * 3) \ No newline at end of file diff --git a/common/time_check.py b/common/time_check.py index 5c2dacba6..15be2895e 100644 --- a/common/time_check.py +++ b/common/time_check.py @@ -1,7 +1,5 @@ -import hashlib import re import time - import config from common.log import logger @@ -10,31 +8,33 @@ def time_checker(f): def _time_checker(self, *args, **kwargs): _config = config.conf() chat_time_module = _config.get("chat_time_module", False) + if chat_time_module: chat_start_time = _config.get("chat_start_time", "00:00") - chat_stopt_time = _config.get("chat_stop_time", "24:00") - time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00 + chat_stop_time = _config.get("chat_stop_time", "24:00").replace("24:00", "23:59") # strptime无法解析24:00,归一化为23:59 - starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 - stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 - chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 + time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") - # 时间格式检查 - if not (starttime_format_check and stoptime_format_check and chat_time_check): - logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check)) - if chat_start_time > "23:59": - logger.error("启动时间可能存在问题,请修改!") + if not (time_regex.match(chat_start_time) and time_regex.match(chat_stop_time)): + logger.warning("时间格式不正确,请在config.json中修改CHAT_START_TIME/CHAT_STOP_TIME。") + return None - # 服务时间检查 - now_time = time.strftime("%H:%M", time.localtime()) - if chat_start_time <= now_time <= chat_stopt_time: # 服务时间内,正常返回回答 + now_time = time.strptime(time.strftime("%H:%M"), "%H:%M") + chat_start_time = time.strptime(chat_start_time, "%H:%M") + chat_stop_time = time.strptime(chat_stop_time, "%H:%M") + # 结束时间小于开始时间,跨天了 + if chat_stop_time < chat_start_time and (chat_start_time <= now_time or now_time <=
chat_stop_time): + f(self, *args, **kwargs) + # 结束大于开始时间代表,没有跨天 + elif chat_start_time < chat_stop_time and chat_start_time <= now_time <= chat_stop_time: f(self, *args, **kwargs) - return None else: - if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置 + # 定义匹配规则,如果以 #reconf 或者 #更新配置 结尾, 非服务时间可以修改开始/结束时间并重载配置 + pattern = re.compile(r"^.*#(?:reconf|更新配置)$") + if args and pattern.match(args[0].content): f(self, *args, **kwargs) else: - logger.info("非服务时间内,不接受访问") + logger.info("非服务时间内,不接受访问") return None else: f(self, *args, **kwargs) # 未开启时间模块则直接回答 diff --git a/common/utils.py b/common/utils.py index 966a7cf1f..2349898e4 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,8 +1,8 @@ import io import os - +from urllib.parse import urlparse from PIL import Image - +from common.log import logger def fsize(file): if isinstance(file, io.BytesIO): @@ -49,3 +49,22 @@ def split_string_by_utf8_length(string, max_length, max_split=0): result.append(encoded[start:end].decode("utf-8")) start = end return result + + +def get_path_suffix(path): + path = urlparse(path).path + return os.path.splitext(path)[-1].lstrip('.') + + +def convert_webp_to_png(webp_image): + from PIL import Image + try: + webp_image.seek(0) + img = Image.open(webp_image).convert("RGBA") + png_image = io.BytesIO() + img.save(png_image, format="PNG") + png_image.seek(0) + return png_image + except Exception as e: + logger.error(f"Failed to convert WEBP to PNG: {e}") + raise diff --git a/config-template.json b/config-template.json index f18f83b43..d0268d3b1 100644 --- a/config-template.json +++ b/config-template.json @@ -1,7 +1,8 @@ { "channel_type": "wx", + "model": "", "open_ai_api_key": "YOUR API KEY", - "model": "gpt-3.5-turbo", + "claude_api_key": "YOUR API KEY", "text_to_image": "dall-e-2", "voice_to_text": "openai", "text_to_voice": "openai", @@ -19,22 +20,17 @@ "ChatGPT测试群", "ChatGPT测试群2" ], - "group_chat_in_one_session": [ - "ChatGPT测试群" - ], "image_create_prefix": [ "画" ], "speech_recognition": 
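The rewritten `time_checker` above compares `time.struct_time` values so the service window may cross midnight (stop earlier than start, e.g. 22:00–06:00). The window test can be sketched on its own; note that `time.strptime` cannot parse `"24:00"`, so inputs here are assumed to be valid `HH:MM` with hours 0–23:

```python
import time

def in_service_window(now_hm: str, start_hm: str, stop_hm: str) -> bool:
    """Decide whether now_hm falls inside the [start, stop] service
    window, allowing the window to cross midnight."""
    now = time.strptime(now_hm, "%H:%M")
    start = time.strptime(start_hm, "%H:%M")
    stop = time.strptime(stop_hm, "%H:%M")
    if stop < start:
        # 跨天: in service from start until midnight, then until stop
        return start <= now or now <= stop
    return start <= now <= stop
```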
true, "group_speech_recognition": false, "voice_reply_voice": false, - "tts_voice_id": "alloy", - "conversation_max_tokens": 1000, + "conversation_max_tokens": 2500, "expires_in_seconds": 3600, - "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", + "character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", "temperature": 0.7, - "top_p": 1, - "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。", + "subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。", "use_linkai": false, "linkai_api_key": "", "linkai_app_code": "" diff --git a/config.py b/config.py index 53d8bf043..a28c5b959 100644 --- a/config.py +++ b/config.py @@ -4,6 +4,7 @@ import logging import os import pickle +import copy from common.log import logger @@ -16,7 +17,8 @@ "open_ai_api_base": "https://api.openai.com/v1", "proxy": "", # openai使用的代理 # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 - "model": "gpt-3.5-turbo", # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei + "model": "gpt-3.5-turbo", # 可选择: gpt-4o, gpt-4o-mini, gpt-4-turbo, claude-3-sonnet, wenxin, moonshot, qwen-turbo, xunfei, glm-4, minimax, gemini等模型,全部可选模型详见common/const.py文件 + "bot_type": "", # 可选配置,使用兼容openai格式的三方服务时候,需填"chatGPT"。bot具体名称详见common/const.py文件列出的bot_type,如不填则根据model名称判断 "use_azure_chatgpt": False, # 是否使用azure的chatgpt "azure_deployment_id": "", # azure 模型部署名称 "azure_api_version": "", # azure api版本 @@ -25,6 +27,7 @@ "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 "single_chat_reply_suffix": "", # 私聊时自动回复的后缀,\n 可以换行 "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 + "no_need_at": False, # 群聊回复时是否不需要艾特 "group_chat_reply_prefix": "", # 群聊时自动回复的前缀 "group_chat_reply_suffix": "", # 群聊时自动回复的后缀,\n 可以换行 "group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复 @@ -32,13 +35,22 @@
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 "group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表 "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 - "group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎 + "nick_name_black_list": [], # 用户昵称黑名单 + "group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎 "trigger_by_self": False, # 是否允许机器人触发 "text_to_image": "dall-e-2", # 图片生成模型,可选 dall-e-2, dall-e-3 + # Azure OpenAI dall-e-3 配置 + "dalle3_image_style": "vivid", # 图片生成dalle3的风格,可选有 vivid, natural + "dalle3_image_quality": "hd", # 图片生成dalle3的质量,可选有 standard, hd + # Azure OpenAI DALL-E API 配置, 当use_azure_chatgpt为true时,用于将文字回复的资源和Dall-E的资源分开. + "azure_openai_dalle_api_base": "", # [可选] azure openai 用于回复图片的资源 endpoint,默认使用 open_ai_api_base + "azure_openai_dalle_api_key": "", # [可选] azure openai 用于回复图片的资源 key,默认使用 open_ai_api_key + "azure_openai_dalle_deployment_id":"", # [可选] azure openai 用于回复图片的资源 deployment id,默认使用 text_to_image "image_proxy": True, # 是否需要图片代理,国内访问LinkAI时需要 "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 "concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序 "image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 (dall-e-3默认为1024x1024) + "group_chat_exit_group": False, # chatgpt会话参数 "expires_in_seconds": 3600, # 无操作会话的过期时间 # 人格描述 @@ -52,19 +64,34 @@ "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, - "request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 + "request_timeout": 180, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 "timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 # Baidu 文心一言参数 "baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型 "baidu_wenxin_api_key": "", # Baidu api key "baidu_wenxin_secret_key": "", # Baidu secret key + "baidu_wenxin_prompt_enabled": False, # Enable prompt if you are using ernie character model # 讯飞星火API "xunfei_app_id": "", # 讯飞应用ID "xunfei_api_key": "", # 讯飞 API key "xunfei_api_secret": "", # 讯飞 API secret + "xunfei_domain": "", # 
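The new `bot_type` option above lets the bot implementation be chosen explicitly, falling back to inference from the model name when it is left empty. A hypothetical sketch of that resolution; the mapping below is illustrative only, not the project's actual table in common/const.py:

```python
# Illustrative model -> bot_type mapping (assumption, not the real table).
MODEL_TO_BOT = {
    "gpt-4o": "chatGPT",
    "claude-3-sonnet": "claudeAPI",
    "qwen-turbo": "dashscope",
    "wenxin": "baidu",
}

def resolve_bot_type(config: dict) -> str:
    """Prefer an explicit bot_type; otherwise infer it from the model name."""
    if config.get("bot_type"):
        return config["bot_type"]
    model = config.get("model", "")
    try:
        return MODEL_TO_BOT[model]
    except KeyError:
        raise ValueError(f"cannot infer bot_type for model: {model!r}")
```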
讯飞模型对应的domain参数,Spark4.0 Ultra为 4.0Ultra,其他模型详见: https://www.xfyun.cn/doc/spark/Web.html + "xunfei_spark_url": "", # 讯飞模型对应的请求地址,Spark4.0 Ultra为 wss://spark-api.xf-yun.com/v4.0/chat,其他模型详见: https://www.xfyun.cn/doc/spark/Web.html # claude 配置 "claude_api_cookie": "", "claude_uuid": "", + # claude api key + "claude_api_key": "", + # 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html + "qwen_access_key_id": "", + "qwen_access_key_secret": "", + "qwen_agent_key": "", + "qwen_app_id": "", + "qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串 + # 阿里灵积(通义新版sdk)模型api key + "dashscope_api_key": "", + # Google Gemini Api Key + "gemini_api_key": "", # wework的通用配置 "wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开 # 语音设置 @@ -72,8 +99,8 @@ "group_speech_recognition": False, # 是否开启群组语音识别 "voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key "always_reply_voice": False, # 是否一直使用语音回复 - "voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure - "text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,pytts(offline),azure,elevenlabs + "voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure,xunfei,ali + "text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,azure,xunfei,ali,pytts(offline),elevenlabs,edge(online) "text_to_voice_model": "tts-1", "tts_voice_id": "alloy", # baidu 语音api配置, 使用百度语音识别和语音合成时需要 @@ -81,13 +108,13 @@ "baidu_api_key": "", "baidu_secret_key": "", # 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场 - "baidu_dev_pid": "1536", + "baidu_dev_pid": 1536, # azure 语音api配置, 使用azure语音识别和语音合成时需要 "azure_voice_api_key": "", "azure_voice_region": "japaneast", # elevenlabs 语音api配置 - "xi_api_key": "", #获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication - "xi_voice_id": "", #ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam” + "xi_api_key": "", # 获取api key的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication + "xi_voice_id": "", # 
ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam” # 服务时间限制,目前支持itchat "chat_time_module": False, # 是否开启服务时间限制 "chat_start_time": "00:00", # 服务开始时间 @@ -115,10 +142,21 @@ "wechatcomapp_secret": "", # 企业微信app的secret "wechatcomapp_agent_id": "", # 企业微信app的agent_id "wechatcomapp_aes_key": "", # 企业微信app的aes_key + # 飞书配置 + "feishu_port": 80, # 飞书bot监听端口 + "feishu_app_id": "", # 飞书机器人应用APP Id + "feishu_app_secret": "", # 飞书机器人APP secret + "feishu_token": "", # 飞书 verification token + "feishu_bot_name": "", # 飞书机器人的名字 + # 钉钉配置 + "dingtalk_client_id": "", # 钉钉机器人Client ID + "dingtalk_client_secret": "", # 钉钉机器人Client Secret + "dingtalk_card_enabled": False, + # chatgpt指令自定义触发词 "clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头 # channel配置 - "channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app} + "channel_type": "", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app,dingtalk} "subscribe_msg": "", # 订阅消息, 支持: wechatmp, wechatmp_service, wechatcom_app "debug": False, # 是否开启debug模式,开启后会打印更多日志 "appdata_dir": "", # 数据目录 @@ -126,11 +164,21 @@ "plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 # 是否使用全局插件配置 "use_global_plugin_config": False, - # 知识库平台配置 + "max_media_send_count": 3, # 单次最大发送媒体资源的个数 + "media_send_interval": 1, # 发送图片的事件间隔,单位秒 + # 智谱AI 平台配置 + "zhipu_ai_api_key": "", + "zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4", + "moonshot_api_key": "", + "moonshot_base_url": "https://api.moonshot.cn/v1/chat/completions", + # LinkAI平台配置 "use_linkai": False, "linkai_api_key": "", "linkai_app_code": "", - "linkai_api_base": "https://api.link-ai.chat", # linkAI服务地址,若国内无法访问或延迟较高可改为 https://api.link-ai.tech + "linkai_api_base": "https://api.link-ai.tech", # linkAI服务地址 + "Minimax_api_key": "", + "Minimax_group_id": "", + "Minimax_base_url": "", } @@ -191,6 +239,30 @@ def save_user_datas(self): config = Config() +def drag_sensitive(config): + try: + if isinstance(config, str): 
+ conf_dict: dict = json.loads(config) + conf_dict_copy = copy.deepcopy(conf_dict) + for key in conf_dict_copy: + if "key" in key or "secret" in key: + if isinstance(conf_dict_copy[key], str): + conf_dict_copy[key] = conf_dict_copy[key][0:3] + "*" * 5 + conf_dict_copy[key][-3:] + return json.dumps(conf_dict_copy, indent=4) + + elif isinstance(config, dict): + config_copy = copy.deepcopy(config) + for key in config: + if "key" in key or "secret" in key: + if isinstance(config_copy[key], str): + config_copy[key] = config_copy[key][0:3] + "*" * 5 + config_copy[key][-3:] + return config_copy + except Exception as e: + logger.exception(e) + return config + return config + + def load_config(): global config config_path = "./config.json" @@ -199,7 +271,7 @@ def load_config(): config_path = "./config-template.json" config_str = read_file(config_path) - logger.debug("[INIT] config str: {}".format(config_str)) + logger.debug("[INIT] config str: {}".format(drag_sensitive(config_str))) # 将json字符串反序列化为dict类型 config = Config(json.loads(config_str)) @@ -224,7 +296,7 @@ def load_config(): logger.setLevel(logging.DEBUG) logger.debug("[INIT] set log level to DEBUG") - logger.info("[INIT] load config: {}".format(config)) + logger.info("[INIT] load config: {}".format(drag_sensitive(config))) config.load_user_datas() @@ -280,6 +352,4 @@ def pconf(plugin_name: str) -> dict: # 全局配置,用于存放全局生效的状态 -global_config = { - "admin_users": [] -} +global_config = {"admin_users": []} diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 8dbb1e41e..39fdffb15 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -6,6 +6,7 @@ services: security_opt: - seccomp:unconfined environment: + TZ: 'Asia/Shanghai' OPEN_AI_API_KEY: 'YOUR API KEY' MODEL: 'gpt-3.5-turbo' PROXY: '' diff --git a/docs/images/aigcopen.png b/docs/images/aigcopen.png deleted file mode 100644 index 76a20c620..000000000 Binary files a/docs/images/aigcopen.png and /dev/null differ diff --git 
a/docs/images/group-chat-sample.jpg b/docs/images/group-chat-sample.jpg deleted file mode 100644 index 35fffdada..000000000 Binary files a/docs/images/group-chat-sample.jpg and /dev/null differ diff --git a/docs/images/image-create-sample.jpg b/docs/images/image-create-sample.jpg deleted file mode 100644 index 5d916c573..000000000 Binary files a/docs/images/image-create-sample.jpg and /dev/null differ diff --git a/docs/images/planet.jpg b/docs/images/planet.jpg deleted file mode 100644 index dffca7f25..000000000 Binary files a/docs/images/planet.jpg and /dev/null differ diff --git a/docs/images/single-chat-sample.jpg b/docs/images/single-chat-sample.jpg deleted file mode 100644 index f24b74d60..000000000 Binary files a/docs/images/single-chat-sample.jpg and /dev/null differ diff --git a/docs/version/old-version.md b/docs/version/old-version.md new file mode 100644 index 000000000..a7b6240fd --- /dev/null +++ b/docs/version/old-version.md @@ -0,0 +1,13 @@ +## 归档更新日志 + +2023.04.26: 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,使用文档。(contributed by @lanvent in #944) + +2023.04.05: 支持微信公众号部署,兼容插件,并支持语音图片交互,使用文档。(contributed by @JS00000 in #686) + +2023.04.05: 增加能让ChatGPT使用工具的tool插件,使用文档。工具相关issue可反馈至chatgpt-tool-hub。(contributed by @goldfishh in #663) + +2023.03.25: 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 #578。(contributed by @lanvent in #565) + +2023.03.09: 基于 whisper API(后续已接入更多的语音API服务) 实现对语音消息的解析和回复,添加配置项 "speech_recognition":true 即可启用,使用参考 #415。(contributed by wanggang1987 in #385) + +2023.02.09: 扫码登录存在账号限制风险,请谨慎使用,参考#58 \ No newline at end of file diff --git a/lib/itchat/LICENSE b/lib/itchat/LICENSE new file mode 100644 index 000000000..ba1a0e273 --- /dev/null +++ b/lib/itchat/LICENSE @@ -0,0 +1,9 @@ +**The MIT License (MIT)** + +Copyright (c) 2017 LittleCoder ([littlecodersh@Github](https://github.com/littlecodersh)) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the 
"Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/plugins/config.json.template b/plugins/config.json.template index 3334a625c..f11fb9391 100644 --- a/plugins/config.json.template +++ b/plugins/config.json.template @@ -10,9 +10,7 @@ }, "tool": { "tools": [ - "python", "url-get", - "terminal", "meteo-weather" ], "kwargs": { @@ -33,6 +31,29 @@ "max_tasks": 3, "max_tasks_per_user": 1, "use_image_create_prefix": true + }, + "summary": { + "enabled": true, + "group_enabled": true, + "max_file_size": 5000, + "type": ["FILE", "SHARING"] } + }, + "hello": { + "group_welc_fixed_msg": { + "群聊1": "群聊1的固定欢迎语", + "群聊2": "群聊2的固定欢迎语" + }, + "group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。", + + "group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。", + + "patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。", + + "use_character_desc": false + }, + "Apilot": { + "alapi_token": "xxx", + "morning_news_text_enabled": false } } diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index e1c47c803..d6383bfa8 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -266,14 +266,16 @@ def 
on_handle_context(self, e_context: EventContext): if not isadmin and not self.is_admin_in_group(e_context["context"]): ok, result = False, "需要管理员权限执行" elif len(args) == 0: - ok, result = True, "当前模型为: " + str(conf().get("model")) + model = conf().get("model") or const.GPT35 + ok, result = True, "当前模型为: " + str(model) elif len(args) == 1: if args[0] not in const.MODEL_LIST: ok, result = False, "模型名称不存在" else: conf()["model"] = self.model_mapping(args[0]) Bridge().reset_bot() - ok, result = True, "模型设置为: " + str(conf().get("model")) + model = conf().get("model") or const.GPT35 + ok, result = True, "模型设置为: " + str(model) elif cmd == "id": ok, result = True, user elif cmd == "set_openai_api_key": @@ -311,7 +313,7 @@ def on_handle_context(self, e_context: EventContext): except Exception as e: ok, result = False, "你没有设置私有GPT模型" elif cmd == "reset": - if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI]: + if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI]: bot.sessions.clear_session(session_id) if Bridge().chat_bots.get(bottype): Bridge().chat_bots.get(bottype).sessions.clear_session(session_id) @@ -337,7 +339,7 @@ def on_handle_context(self, e_context: EventContext): ok, result = True, "配置已重载" elif cmd == "resetall": if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, - const.BAIDU, const.XUNFEI]: + const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT]: channel.cancel_all_session() bot.sessions.clear_all_session() ok, result = True, "重置所有会话成功" @@ -473,3 +475,11 @@ def model_mapping(self, model) -> str: if model == "gpt-4-turbo": return const.GPT4_TURBO_PREVIEW return model + + def reload(self): + gconf = plugin_config[self.name] + if gconf: + if gconf.get("password"): + self.password = gconf["password"] + if gconf.get("admin_users"): + self.admin_users = 
gconf["admin_users"] diff --git a/plugins/hello/README.md b/plugins/hello/README.md new file mode 100644 index 000000000..fb560c441 --- /dev/null +++ b/plugins/hello/README.md @@ -0,0 +1,41 @@ +## 插件说明 + +可以根据需求设置入群欢迎、群聊拍一拍、退群等消息的自定义提示词,也支持为每个群设置对应的固定欢迎语。 + +该插件也是用户根据需求开发自定义插件的示例插件,参考[插件开发说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins) + +## 插件配置 + +将 `plugins/hello` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`。 (如果未配置则会默认使用`config.json.template`模板中的配置)。 + +以下是插件配置项说明: + +```bash +{ + "group_welc_fixed_msg": { ## 这里可以为特定群配置特定的固定欢迎语 + "群聊1": "群聊1的固定欢迎语", + "群聊2": "群聊2的固定欢迎语" + }, + + "group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。", ## 群聊随机欢迎语的提示词 + + "group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。", ## 移出群聊的提示词 + + "patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。", ## 群内拍一拍的提示词 + + "use_character_desc": false ## 是否在Hello插件中使用LinkAI应用的系统设定 +} +``` + + +注意: + + - 设置全局的用户进群固定欢迎语,可以在***项目根目录下***的`config.json`文件里添加参数`"group_welcome_msg": "" `,参考 [#1482](https://github.com/zhayujie/chatgpt-on-wechat/pull/1482) + - 为每个群设置固定的欢迎语,可以在`"group_welc_fixed_msg": {}`配置群聊名和对应的固定欢迎语,优先级高于全局固定欢迎语 + - 如果没有配置以上两个参数,则使用随机欢迎语,如需设定风格,语言等,修改`"group_welc_prompt": `即可 + - 如果使用LinkAI的服务,想在随机欢迎中结合LinkAI应用的设定,配置`"use_character_desc": true ` + - 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释 + - 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8) + + + diff --git a/plugins/hello/config.json.template b/plugins/hello/config.json.template new file mode 100644 index 000000000..13c3788a5 --- /dev/null +++ b/plugins/hello/config.json.template @@ -0,0 +1,14 @@ +{ + "group_welc_fixed_msg": { + "群聊1": "群聊1的固定欢迎语", + "群聊2": "群聊2的固定欢迎语" + }, + + "group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。", + + "group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。", + + "patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。", 
+ + "use_character_desc": false +} \ No newline at end of file diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index dcc248f6d..23de86166 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -17,39 +17,73 @@ version="0.1", author="lanvent", ) + + class Hello(Plugin): + + group_welc_prompt = "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。" + group_exit_prompt = "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。" + patpat_prompt = "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。" + def __init__(self): super().__init__() - self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context - logger.info("[Hello] inited") + try: + self.config = super().load_config() + if not self.config: + self.config = self._load_config_template() + self.group_welc_fixed_msg = self.config.get("group_welc_fixed_msg", {}) + self.group_welc_prompt = self.config.get("group_welc_prompt", self.group_welc_prompt) + self.group_exit_prompt = self.config.get("group_exit_prompt", self.group_exit_prompt) + self.patpat_prompt = self.config.get("patpat_prompt", self.patpat_prompt) + logger.info("[Hello] inited") + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + except Exception as e: + logger.error(f"[Hello]初始化异常:{e}") + raise Exception("[Hello] init failed, ignore") def on_handle_context(self, e_context: EventContext): if e_context["context"].type not in [ ContextType.TEXT, ContextType.JOIN_GROUP, ContextType.PATPAT, + ContextType.EXIT_GROUP ]: return - + msg: ChatMessage = e_context["context"]["msg"] + group_name = msg.from_user_nickname if e_context["context"].type == ContextType.JOIN_GROUP: - if "group_welcome_msg" in conf(): + if "group_welcome_msg" in conf() or group_name in self.group_welc_fixed_msg: reply = Reply() reply.type = ReplyType.TEXT - reply.content = conf().get("group_welcome_msg", "") + if group_name in self.group_welc_fixed_msg: + reply.content = self.group_welc_fixed_msg.get(group_name, "") + else: + reply.content = conf().get("group_welcome_msg", "") 
e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 return e_context["context"].type = ContextType.TEXT - msg: ChatMessage = e_context["context"]["msg"] - e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。' + e_context["context"].content = self.group_welc_prompt.format(nickname=msg.actual_user_nickname) e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑 + if not self.config or not self.config.get("use_character_desc"): + e_context["context"]["generate_breaked_by"] = EventAction.BREAK return - + + if e_context["context"].type == ContextType.EXIT_GROUP: + if conf().get("group_chat_exit_group"): + e_context["context"].type = ContextType.TEXT + e_context["context"].content = self.group_exit_prompt.format(nickname=msg.actual_user_nickname) + e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑 + return + e_context.action = EventAction.BREAK + return + if e_context["context"].type == ContextType.PATPAT: e_context["context"].type = ContextType.TEXT - msg: ChatMessage = e_context["context"]["msg"] - e_context["context"].content = f"请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。" + e_context["context"].content = self.patpat_prompt e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑 + if not self.config or not self.config.get("use_character_desc"): + e_context["context"]["generate_breaked_by"] = EventAction.BREAK return content = e_context["context"].content @@ -57,7 +91,6 @@ def on_handle_context(self, e_context: EventContext): if content == "Hello": reply = Reply() reply.type = ReplyType.TEXT - msg: ChatMessage = e_context["context"]["msg"] if e_context["context"]["isgroup"]: reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" else: @@ -81,3 +114,14 @@ def on_handle_context(self, e_context: EventContext): def get_help_text(self, **kwargs): help_text = "输入Hello,我会回复你的名字\n输入End,我会回复你世界的图片\n" return help_text + + def _load_config_template(self): + logger.debug("No Hello 
plugin config.json, use plugins/hello/config.json.template") + try: + plugin_config_path = os.path.join(self.path, "config.json.template") + if os.path.exists(plugin_config_path): + with open(plugin_config_path, "r", encoding="utf-8") as f: + plugin_conf = json.load(f) + return plugin_conf + except Exception as e: + logger.exception(e) \ No newline at end of file diff --git a/plugins/keyword/keyword.py b/plugins/keyword/keyword.py index 87cd05435..281b8af8e 100644 --- a/plugins/keyword/keyword.py +++ b/plugins/keyword/keyword.py @@ -55,7 +55,7 @@ def on_handle_context(self, e_context: EventContext): reply_text = self.keyword[content] # 判断匹配内容的类型 - if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".img"]): + if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".webp", ".jpeg", ".png", ".gif", ".img"]): # 如果是以 http:// 或 https:// 开头,且".jpg", ".jpeg", ".png", ".gif", ".img"结尾,则认为是图片 URL。 reply = Reply() reply.type = ReplyType.IMAGE_URL diff --git a/plugins/linkai/README.md b/plugins/linkai/README.md index dc3caa776..2ac80b113 100644 --- a/plugins/linkai/README.md +++ b/plugins/linkai/README.md @@ -25,7 +25,8 @@ "summary": { "enabled": true, # 文档总结和对话功能开关 "group_enabled": true, # 是否支持群聊开启 - "max_file_size": 5000 # 文件的大小限制,单位KB,默认为5M,超过该大小直接忽略 + "max_file_size": 5000, # 文件的大小限制,单位KB,默认为5M,超过该大小直接忽略 + "type": ["FILE", "SHARING", "IMAGE"] # 支持总结的类型,分别表示 文件、分享链接、图片,其中文件和链接默认打开,图片默认关闭 } } ``` @@ -99,7 +100,7 @@ #### 使用 -功能开启后,向机器人发送 **文件** 或 **分享链接卡片** 即可生成摘要,进一步可以与文件或链接的内容进行多轮对话。 +功能开启后,向机器人发送 **文件**、 **分享链接卡片**、**图片** 即可生成摘要,进一步可以与文件或链接的内容进行多轮对话。如果需要关闭某种类型的内容总结,设置 `summary`配置中的type字段即可。 #### 限制 diff --git a/plugins/linkai/config.json.template b/plugins/linkai/config.json.template index b6c7a0425..547b8ef8c 100644 --- a/plugins/linkai/config.json.template +++ b/plugins/linkai/config.json.template 
@@ -15,6 +15,6 @@ "enabled": true, "group_enabled": true, "max_file_size": 5000, - "type": ["FILE", "SHARING", "IMAGE"] + "type": ["FILE", "SHARING"] } } diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index 9e35bcd53..c38319a51 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -9,6 +9,8 @@ from common import const import os from .utils import Util +from config import plugin_config + @plugins.register( name="linkai", @@ -32,7 +34,6 @@ def __init__(self): self.sum_config = self.config.get("summary") logger.info(f"[LinkAI] inited, config={self.config}") - def on_handle_context(self, e_context: EventContext): """ 消息处理逻辑 @@ -42,7 +43,8 @@ def on_handle_context(self, e_context: EventContext): return context = e_context['context'] - if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE, ContextType.FILE, ContextType.SHARING]: + if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE, ContextType.FILE, + ContextType.SHARING]: # filter content no need solve return @@ -68,7 +70,7 @@ def on_handle_context(self, e_context: EventContext): return if (context.type == ContextType.SHARING and self._is_summary_open(context)) or \ - (context.type == ContextType.TEXT and LinkSummary().check_url(context.content)): + (context.type == ContextType.TEXT and self._is_summary_open(context) and LinkSummary().check_url(context.content)): if not LinkSummary().check_url(context.content): return _send_info(e_context, "正在为你加速生成摘要,请稍后") @@ -76,7 +78,8 @@ def on_handle_context(self, e_context: EventContext): if not res: _set_reply_text("因为神秘力量无法获取文章内容,请稍后再试吧~", e_context, level=ReplyType.TEXT) return - _set_reply_text(res.get("summary") + "\n\n💬 发送 \"开启对话\" 可以开启与文章内容的对话", e_context, level=ReplyType.TEXT) + _set_reply_text(res.get("summary") + "\n\n💬 发送 \"开启对话\" 可以开启与文章内容的对话", e_context, + level=ReplyType.TEXT) USER_FILE_MAP[_find_user_id(context) + "-sum_id"] = res.get("summary_id") return @@ -99,7 
+102,8 @@ def on_handle_context(self, e_context: EventContext): _set_reply_text("开启对话失败,请稍后再试吧", e_context) return USER_FILE_MAP[_find_user_id(context) + "-file_id"] = res.get("file_id") - _set_reply_text("💡你可以问我关于这篇文章的任何问题,例如:\n\n" + res.get("questions") + "\n\n发送 \"退出对话\" 可以关闭与文章的对话", e_context, level=ReplyType.TEXT) + _set_reply_text("💡你可以问我关于这篇文章的任何问题,例如:\n\n" + res.get( + "questions") + "\n\n发送 \"退出对话\" 可以关闭与文章的对话", e_context, level=ReplyType.TEXT) return if context.type == ContextType.TEXT and context.content == "退出对话" and _find_file_id(context): @@ -117,12 +121,10 @@ def on_handle_context(self, e_context: EventContext): e_context.action = EventAction.BREAK_PASS return - if self._is_chat_task(e_context): # 文本对话任务处理 self._process_chat_task(e_context) - # 插件管理功能 def _process_admin_cmd(self, e_context: EventContext): context = e_context['context'] @@ -177,7 +179,9 @@ def _process_admin_cmd(self, e_context: EventContext): tips_text = "关闭" is_open = False if not self.sum_config: - _set_reply_text(f"插件未启用summary功能,请参考以下链添加插件配置\n\nhttps://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/linkai/README.md", e_context, level=ReplyType.INFO) + _set_reply_text( + f"插件未启用summary功能,请参考以下链接添加插件配置\n\nhttps://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/linkai/README.md", + e_context, level=ReplyType.INFO) else: self.sum_config["enabled"] = is_open _set_reply_text(f"文章总结功能{tips_text}", e_context, level=ReplyType.INFO) @@ -192,10 +196,8 @@ def _is_summary_open(self, context) -> bool: return False if context.kwargs.get("isgroup") and not self.sum_config.get("group_enabled"): return False - support_type = self.sum_config.get("type") - if not support_type: - return True - if context.type.name not in support_type: + support_type = self.sum_config.get("type") or ["FILE", "SHARING"] + if context.type.name not in support_type and context.type.name != "TEXT": return False return True @@ -252,10 +254,14 @@ def _load_config_template(self): plugin_conf = json.load(f) 
plugin_conf["midjourney"]["enabled"] = False plugin_conf["summary"]["enabled"] = False + plugin_config["linkai"] = plugin_conf return plugin_conf except Exception as e: logger.exception(e) + def reload(self): + self.config = super().load_config() + def _send_info(e_context: EventContext, content: str): reply = Reply(ReplyType.TEXT, content) @@ -275,15 +281,19 @@ def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = Re e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS + def _get_trigger_prefix(): return conf().get("plugin_trigger_prefix", "$") + def _find_sum_id(context): return USER_FILE_MAP.get(_find_user_id(context) + "-sum_id") + def _find_file_id(context): user_id = _find_user_id(context) if user_id: return USER_FILE_MAP.get(user_id + "-file_id") + USER_FILE_MAP = ExpiredDict(conf().get("expires_in_seconds") or 60 * 30) diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 76395bd5a..6500e573d 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -68,7 +68,7 @@ def __str__(self): # midjourney bot class MJBot: def __init__(self, config): - self.base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/img/midjourney" + self.base_url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/img/midjourney" self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} self.config = config self.tasks = {} @@ -88,6 +88,8 @@ def judge_mj_task_type(self, e_context: EventContext): context = e_context['context'] if context.type == ContextType.TEXT: cmd_list = context.content.split(maxsplit=1) + if not cmd_list: + return None if cmd_list[0].lower() == f"{trigger_prefix}mj": return TaskType.GENERATE elif cmd_list[0].lower() == f"{trigger_prefix}mju": diff --git a/plugins/linkai/summary.py b/plugins/linkai/summary.py index c945896b0..84d74bcc6 100644 --- a/plugins/linkai/summary.py +++ b/plugins/linkai/summary.py @@ -2,6 +2,7 @@ from config 
import conf from common.log import logger import os +import html class LinkSummary: @@ -18,6 +19,7 @@ def summary_file(self, file_path: str): return self._parse_summary_res(res) def summary_url(self, url: str): + url = html.unescape(url) body = { "url": url } @@ -59,7 +61,7 @@ def _parse_summary_res(self, res): return None def base_url(self): - return conf().get("linkai_api_base", "https://api.link-ai.chat") + return conf().get("linkai_api_base", "https://api.link-ai.tech") def headers(self): return {"Authorization": "Bearer " + conf().get("linkai_api_key")} @@ -91,5 +93,4 @@ def check_url(self, url: str): for support_url in support_list: if url.strip().startswith(support_url): return True - logger.debug(f"[LinkSum] unsupported url, no need to process, url={url}") return False diff --git a/plugins/plugin.py b/plugins/plugin.py index 801997b99..028e221db 100644 --- a/plugins/plugin.py +++ b/plugins/plugin.py @@ -18,6 +18,7 @@ def load_config(self) -> dict: if not plugin_conf: # 全局配置不存在,则获取插件目录下的配置 plugin_config_path = os.path.join(self.path, "config.json") + logger.debug(f"loading plugin config, plugin_config_path={plugin_config_path}, exist={os.path.exists(plugin_config_path)}") if os.path.exists(plugin_config_path): with open(plugin_config_path, "r", encoding="utf-8") as f: plugin_conf = json.load(f) @@ -46,3 +47,6 @@ def save_config(self, config: dict): def get_help_text(self, **kwargs): return "暂无帮助信息" + + def reload(self): + pass diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 49c13ca0e..cecf75d65 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -99,7 +99,7 @@ def scan_plugins(self): try: self.current_plugin_path = plugin_path if plugin_path in self.loaded: - if self.loaded[plugin_path] == None: + if plugin_name.upper() != 'GODCMD': logger.info("reload module %s" % plugin_name) self.loaded[plugin_path] = importlib.reload(sys.modules[import_path]) dependent_module_names = [name for name in sys.modules.keys() if 
name.startswith(import_path + ".")] @@ -141,19 +141,21 @@ def activate_plugins(self): # 生成新开启的插件实例 failed_plugins = [] for name, plugincls in self.plugins.items(): if plugincls.enabled: - if name not in self.instances: - try: - instance = plugincls() - except Exception as e: - logger.warn("Failed to init %s, diabled. %s" % (name, e)) - self.disable_plugin(name) - failed_plugins.append(name) - continue - self.instances[name] = instance - for event in instance.handlers: - if event not in self.listening_plugins: - self.listening_plugins[event] = [] - self.listening_plugins[event].append(name) + if 'GODCMD' in self.instances and name == 'GODCMD': + continue + # if name not in self.instances: + try: + instance = plugincls() + except Exception as e: + logger.warn("Failed to init %s, disabled. %s" % (name, e)) + self.disable_plugin(name) + failed_plugins.append(name) + continue + self.instances[name] = instance + for event in instance.handlers: + if event not in self.listening_plugins: + self.listening_plugins[event] = [] + self.listening_plugins[event].append(name) self.refresh_order() return failed_plugins diff --git a/plugins/role/role.py b/plugins/role/role.py index c75aa905a..7c7b1067b 100644 --- a/plugins/role/role.py +++ b/plugins/role/role.py @@ -99,7 +99,8 @@ def on_handle_context(self, e_context: EventContext): if e_context["context"].type != ContextType.TEXT: return btype = Bridge().get_bot_type("chat") - if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]: + if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.QWEN_DASHSCOPE, const.XUNFEI, const.BAIDU, const.ZHIPU_AI, const.MOONSHOT, const.MiniMax, const.LINKAI]: + logger.debug(f'不支持的bot: {btype}') return bot = Bridge().get_bot("chat") content = e_context["context"].content[:] diff --git a/plugins/source.json b/plugins/source.json index d53c996ba..3e97bddc5 100644 --- a/plugins/source.json +++ b/plugins/source.json @@ -19,6 +19,26 @@ "Apilot": { "url": 
"https://github.com/6vision/Apilot.git", "desc": "通过api直接查询早报、热榜、快递、天气等实用信息的插件" + }, + "pictureChange": { + "url": "https://github.com/Yanyutin753/pictureChange.git", + "desc": "1. 支持百度AI和Stable Diffusion WebUI进行图像处理,提供多种模型选择,支持图生图、文生图自定义模板。2. 支持Suno音乐AI可将图像和文字转为音乐。3. 支持自定义模型进行文件、图片总结功能。4. 支持管理员控制群聊内容与参数和功能改变。" + }, + "Blackroom": { + "url": "https://github.com/dividduang/blackroom.git", + "desc": "小黑屋插件,被拉进小黑屋的人将不能使用@bot的功能的插件" + }, + "midjourney": { + "url": "https://github.com/baojingyu/midjourney.git", + "desc": "利用midjourney实现ai绘图的插件" + }, + "solitaire": { + "url": "https://github.com/Wang-zhechao/solitaire.git", + "desc": "机器人微信接龙插件" + }, + "HighSpeedTicket": { + "url": "https://github.com/He0607/HighSpeedTicket.git", + "desc": "高铁(火车)票查询插件" } } } diff --git a/plugins/tool/README.md b/plugins/tool/README.md index 229da4267..4b3cbcd9d 100644 --- a/plugins/tool/README.md +++ b/plugins/tool/README.md @@ -3,11 +3,19 @@ 使用说明(默认trigger_prefix为$): ```text #help tool: 查看tool帮助信息,可查看已加载工具列表 -$tool 命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。 +$tool 工具名 命令: (pure模式)根据给出的{命令}使用指定 一个 可用工具尽力为你得到结果。 +$tool 命令: (多工具模式)根据给出的{命令}使用 一些 可用工具尽力为你得到结果。 $tool reset: 重置工具。 ``` ### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) +2024.01.16更新 +1. 新增工具pure模式,支持单个工具调用 +2. 新增消息转发工具:email, sms, wechat, 可以根据规则向其他平台发送消息 +3. 替换visual-dl(更名为visual)实现,目前识别图片链接效果较好。 +4. 修复了0.4版本大部分工具返回结果不可靠问题 + +新版本工具名共19个,不一一列举,相应工具需要的环境参数见`tool.py`里的`_build_tool_kwargs`函数 ## 使用说明 使用该插件后将默认使用4个工具, 无需额外配置长期生效: @@ -24,7 +32,7 @@ $tool reset: 重置工具。 > 注1:url-get默认配置、browser需额外配置,browser依赖google-chrome,你需要提前安装好 -> 注2:当检测到长文本时会进入summary tool总结长文本,tokens可能会大量消耗! +> 注2:(可通过`browser_use_summary`或 `url_get_use_summary`开关)当检测到长文本时会进入summary tool总结长文本,tokens可能会大量消耗! 
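例如,若想关闭上述两个长文本自动总结开关以节省tokens,可在tool插件配置的 `kwargs` 中将其置为 false(以下仅为示意片段,字段名以 chatgpt-tool-hub 实际支持的参数为准):

```json
{
  "tools": ["url-get", "browser"],
  "kwargs": {
    "browser_use_summary": false,
    "url_get_use_summary": false
  }
}
```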
这是debian端安装google-chrome教程,其他系统请自行查找 > https://www.linuxjournal.com/content/how-can-you-install-google-browser-debian @@ -34,9 +42,10 @@ $tool reset: 重置工具。 -### 4. meteo-weather +### 4. meteo ###### 回答你有关天气的询问, 需要获取时间、地点上下文信息,本工具使用了[meteo open api](https://open-meteo.com/) 注:该工具需要较高的对话技巧,不保证你问的任何问题均能得到满意的回复 +注2:当前版本可只使用这个工具,返回结果较可控。 > meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334 @@ -65,18 +74,12 @@ $tool reset: 重置工具。 #### 6.2. morning-news * ###### 每日60秒早报,每天凌晨一点更新,本工具使用了[alapi-每日60秒早报](https://alapi.cn/api/view/93) -```text -可配置参数: -1. morning_news_use_llm: 是否使用LLM润色结果,默认false(可能会慢) -``` - > 该tool每天返回内容相同 #### 6.3. finance-news ###### 获取实时的金融财政新闻 -> 该工具需要解决browser tool 的google-chrome依赖安装 - + > 该工具需要用到browser工具解决反爬问题 ### 7. bing-search * @@ -99,18 +102,33 @@ $tool reset: 重置工具。 > 0.4.2更新,例子:帮我找一篇吴恩达写的论文 ### 11. summary -###### 总结工具,该工具必须输入一个本地文件的绝对路径 +###### 总结工具,该工具可以支持输入url > 该工具目前是和其他工具配合使用,暂未测试单独使用效果 -### 12. image2text -###### 将图片转换成文字,底层调用imageCaption模型,该工具必须输入一个本地文件的绝对路径 +### 12. visual +###### 将图片转换成文字,底层调用ali dashscope `qwen-vl-plus`模型 ### 13. searxng-search * ###### 一个私有化的搜索引擎工具 > 安装教程:https://docs.searxng.org/admin/installation.html +### 14. email * +###### 发送邮件 + +### 15. sms * +###### 发送短信 + +### 16. stt * +###### speech to text 语音识别 + +### 17. tts * +###### text to speech 文生语音 + +### 18. 
wechat * +###### 向好友、群组发送微信 + --- ###### 注1:带*工具需要获取api-key才能使用(在config.json内的kwargs添加项),部分工具需要外网支持 @@ -120,7 +138,7 @@ $tool reset: 重置工具。 ###### 默认工具无需配置,其它工具需手动配置,以增加morning-news和bing-search两个工具为例: ```json { - "tools": ["bing-search", "news", "你想要添加的其他工具"], // 填入你想用到的额外工具名,这里加入了工具"bing-search"和工具"news"(news工具会自动加载morning-news、finance-news等子工具) + "tools": ["bing-search", "morning-news", "你想要添加的其他工具"], // 填入你想用到的额外工具名,这里加入了工具"bing-search"和工具"morning-news" "kwargs": { "debug": true, // 当你遇到问题求助时,需要配置 "request_timeout": 120, // openai接口超时时间 @@ -137,7 +155,6 @@ $tool reset: 重置工具。 - `debug`: 输出chatgpt-tool-hub额外信息用于调试 - `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置 - `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具 - - `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2 - `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认 --- diff --git a/plugins/tool/config.json.template b/plugins/tool/config.json.template index 00d643d7d..aead6570e 100644 --- a/plugins/tool/config.json.template +++ b/plugins/tool/config.json.template @@ -1,12 +1,10 @@ { "tools": [ - "python", "url-get", - "terminal", - "meteo-weather" + "meteo" ], "kwargs": { - "top_k_results": 2, + "debug": false, "no_default": false, "model_name": "gpt-3.5-turbo" } diff --git a/plugins/tool/tool.py b/plugins/tool/tool.py index b99eabbf9..fe36a6836 100644 --- a/plugins/tool/tool.py +++ b/plugins/tool/tool.py @@ -1,23 +1,20 @@ -import json -import os - from chatgpt_tool_hub.apps import AppFactory from chatgpt_tool_hub.apps.app import App -from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names +from chatgpt_tool_hub.tools.tool_register import main_tool_register import plugins from bridge.bridge import Bridge from bridge.context import ContextType from bridge.reply import Reply, ReplyType from common import const -from config import conf +from config import conf, get_appdata_dir from plugins import * @plugins.register( name="tool", desc="Arming your ChatGPT 
bot with various tools", - version="0.4", + version="0.5", author="goldfishh", desire_priority=0, ) @@ -25,21 +22,25 @@ class Tool(Plugin): def __init__(self): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context - self.app = self._reset_app() - + if not self.tool_config.get("tools"): + logger.warn("[tool] init failed, ignore ") + raise Exception("config.json not found") logger.info("[tool] inited") + def get_help_text(self, verbose=False, **kwargs): help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。" trigger_prefix = conf().get("plugin_trigger_prefix", "$") if not verbose: return help_text help_text += "\n使用说明:\n" - help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" + help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}模型来选择使用哪些工具尽力为你得到结果。\n" + help_text += f"{trigger_prefix}tool 工具名 " + "命令: 根据给出的{命令}使用指定工具尽力为你得到结果。\n" help_text += f"{trigger_prefix}tool reset: 重置工具。\n\n" + help_text += f"已加载工具列表: \n" - for idx, tool in enumerate(self.app.get_tool_list()): + for idx, tool in enumerate(main_tool_register.get_registered_tool_names()): if idx != 0: help_text += ", " help_text += f"{tool}" @@ -91,17 +92,28 @@ def on_handle_context(self, e_context: EventContext): e_context.action = EventAction.BREAK return - query = content_list[1].strip() + + use_one_tool = False + for tool_name in main_tool_register.get_registered_tool_names(): + if query.startswith(tool_name): + use_one_tool = True + query = query[len(tool_name):] + break # Don't modify bot name all_sessions = Bridge().get_bot("chat").sessions user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages - # chatgpt-tool-hub will reply you with many tools logger.debug("[tool]: just-go") try: - _reply = self.app.ask(query, user_session) + if use_one_tool: + _func, _ = main_tool_register.get_registered_tool()[tool_name] + tool = _func(**self.app_kwargs) + _reply = tool.run(query) + else: + # chatgpt-tool-hub will reply you 
with many tools + _reply = self.app.ask(query, user_session) e_context.action = EventAction.BREAK_PASS all_sessions.session_reply(_reply, e_context["context"]["session_id"]) except Exception as e: @@ -126,53 +138,111 @@ def _build_tool_kwargs(self, kwargs: dict): request_timeout = kwargs.get("request_timeout") return { - "debug": kwargs.get("debug", False), - "openai_api_key": conf().get("open_ai_api_key", ""), - "open_ai_api_base": conf().get("open_ai_api_base", "https://api.openai.com/v1"), - "deployment_id": conf().get("azure_deployment_id", ""), - "proxy": conf().get("proxy", ""), + # 全局配置相关 + "log": False, # tool 日志开关 + "debug": kwargs.get("debug", False), # 输出更多日志 + "no_default": kwargs.get("no_default", False), # 不要默认的工具,只加载自己导入的工具 + "think_depth": kwargs.get("think_depth", 2), # 一个问题最多使用多少次工具 + "proxy": conf().get("proxy", ""), # 科学上网 "request_timeout": request_timeout if request_timeout else conf().get("request_timeout", 120), + "temperature": kwargs.get("temperature", 0), # llm 温度,建议设置0 + # LLM配置相关 + "llm_api_key": conf().get("open_ai_api_key", ""), # 如果llm api用key鉴权,传入这里 + "llm_api_base_url": conf().get("open_ai_api_base", "https://api.openai.com/v1"), # 支持openai接口的llm服务地址前缀 + "deployment_id": conf().get("azure_deployment_id", ""), # azure openai会用到 # note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置 - "model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"), - "no_default": kwargs.get("no_default", False), - "top_k_results": kwargs.get("top_k_results", 3), - # for news tool - "news_api_key": kwargs.get("news_api_key", ""), + "model_name": tool_model_name if tool_model_name else conf().get("model", const.GPT35), + # 工具配置相关 + # for arxiv tool + "arxiv_simple": kwargs.get("arxiv_simple", True), # 返回内容更精简 + "arxiv_top_k_results": kwargs.get("arxiv_top_k_results", 2), # 只返回前k个搜索结果 + "arxiv_sort_by": kwargs.get("arxiv_sort_by", "relevance"), # 搜索排序方式 ["relevance","lastUpdatedDate","submittedDate"] + "arxiv_sort_order": 
kwargs.get("arxiv_sort_order", "descending"), # 搜索排序方式 ["ascending", "descending"] + "arxiv_output_type": kwargs.get("arxiv_output_type", "text"), # 搜索结果类型 ["text", "pdf", "all"] # for bing-search tool "bing_subscription_key": kwargs.get("bing_subscription_key", ""), + "bing_search_url": kwargs.get("bing_search_url", "https://api.bing.microsoft.com/v7.0/search"), # 必应搜索的endpoint地址,无需修改 + "bing_search_top_k_results": kwargs.get("bing_search_top_k_results", 2), # 只返回前k个搜索结果 + "bing_search_simple": kwargs.get("bing_search_simple", True), # 返回内容更精简 + "bing_search_output_type": kwargs.get("bing_search_output_type", "text"), # 搜索结果类型 ["text", "json"] + # for email tool + "email_nickname_mapping": kwargs.get("email_nickname_mapping", "{}"), # 关于人的代号对应的邮箱地址,可以不输入邮箱地址发送邮件。键为代号值为邮箱地址 + "email_smtp_host": kwargs.get("email_smtp_host", ""), # 例如 'smtp.qq.com' + "email_smtp_port": kwargs.get("email_smtp_port", ""), # 例如 587 + "email_sender": kwargs.get("email_sender", ""), # 发送者的邮件地址 + "email_authorization_code": kwargs.get("email_authorization_code", ""), # 发送者验证秘钥(可能不是登录密码) # for google-search tool "google_api_key": kwargs.get("google_api_key", ""), "google_cse_id": kwargs.get("google_cse_id", ""), + "google_simple": kwargs.get("google_simple", True), # 返回内容更精简 + "google_output_type": kwargs.get("google_output_type", "text"), # 搜索结果类型 ["text", "json"] + # for finance-news tool + "finance_news_filter": kwargs.get("finance_news_filter", False), # 是否开启过滤 + "finance_news_filter_list": kwargs.get("finance_news_filter_list", []), # 过滤词列表 + "finance_news_simple": kwargs.get("finance_news_simple", True), # 返回内容更精简 + "finance_news_repeat_news": kwargs.get("finance_news_repeat_news", False), # 是否过滤不返回。该tool每次返回约50条新闻,可能有重复新闻 + # for morning-news tool + "morning_news_api_key": kwargs.get("morning_news_api_key", ""), # api-key + "morning_news_simple": kwargs.get("morning_news_simple", True), # 返回内容更精简 + "morning_news_output_type": kwargs.get("morning_news_output_type", "text"), # 搜索结果类型 
["text", "image"] + # for news-api tool + "news_api_key": kwargs.get("news_api_key", ""), # for searxng-search tool - "searx_search_host": kwargs.get("searx_search_host", ""), + "searxng_search_host": kwargs.get("searxng_search_host", ""), + "searxng_search_top_k_results": kwargs.get("searxng_search_top_k_results", 2), # 只返回前k个搜索结果 + "searxng_search_output_type": kwargs.get("searxng_search_output_type", "text"), # 搜索结果类型 ["text", "json"] + # for sms tool + "sms_nickname_mapping": kwargs.get("sms_nickname_mapping", "{}"), # 关于人的代号对应的手机号,可以不输入手机号发送sms。键为代号值为手机号 + "sms_username": kwargs.get("sms_username", ""), # smsbao用户名 + "sms_apikey": kwargs.get("sms_apikey", ""), # smsbao + # for stt tool + "stt_api_key": kwargs.get("stt_api_key", ""), # azure + "stt_api_region": kwargs.get("stt_api_region", ""), # azure + "stt_recognition_language": kwargs.get("stt_recognition_language", "zh-CN"), # 识别的语言类型 部分:en-US ja-JP ko-KR yue-CN zh-CN + # for tts tool + "tts_api_key": kwargs.get("tts_api_key", ""), # azure + "tts_api_region": kwargs.get("tts_api_region", ""), # azure + "tts_auto_detect": kwargs.get("tts_auto_detect", True), # 是否自动检测语音的语言 + "tts_speech_id": kwargs.get("tts_speech_id", "zh-CN-XiaozhenNeural"), # 输出语音ID + # for summary tool + "summary_max_segment_length": kwargs.get("summary_max_segment_length", 2500), # 每2500tokens分段,多段触发总结tool + # for terminal tool + "terminal_nsfc_filter": kwargs.get("terminal_nsfc_filter", True), # 是否过滤llm输出的危险命令 + "terminal_return_err_output": kwargs.get("terminal_return_err_output", True), # 是否输出错误信息 + "terminal_timeout": kwargs.get("terminal_timeout", 20), # 允许命令最长执行时间 + # for visual tool + "caption_api_key": kwargs.get("caption_api_key", ""), # ali dashscope apikey + # for browser tool + "browser_use_summary": kwargs.get("browser_use_summary", True), # 是否对返回结果使用tool功能 + # for url-get tool + "url_get_use_summary": kwargs.get("url_get_use_summary", True), # 是否对返回结果使用tool功能 + # for wechat tool + "wechat_hot_reload": 
kwargs.get("wechat_hot_reload", True), # 是否使用热重载的方式发送wechat + "wechat_cpt_path": kwargs.get("wechat_cpt_path", os.path.join(get_appdata_dir(), "itchat.pkl")), # wechat 配置文件(`itchat.pkl`) + "wechat_send_group": kwargs.get("wechat_send_group", False), # 是否向群组发送消息 + "wechat_nickname_mapping": kwargs.get("wechat_nickname_mapping", "{}"), # 关于人的代号映射关系。键为代号值为微信名(昵称、备注名均可) + # for wikipedia tool + "wikipedia_top_k_results": kwargs.get("wikipedia_top_k_results", 2), # 只返回前k个搜索结果 # for wolfram-alpha tool "wolfram_alpha_appid": kwargs.get("wolfram_alpha_appid", ""), - # for morning-news tool - "morning_news_api_key": kwargs.get("morning_news_api_key", ""), - # for visual_dl tool - "cuda_device": kwargs.get("cuda_device", "cpu"), - "think_depth": kwargs.get("think_depth", 3), - "arxiv_summary": kwargs.get("arxiv_summary", True), - "morning_news_use_llm": kwargs.get("morning_news_use_llm", False), } def _filter_tool_list(self, tool_list: list): valid_list = [] for tool in tool_list: - if tool in get_all_tool_names(): + if tool in main_tool_register.get_registered_tool_names(): valid_list.append(tool) else: logger.warning("[tool] filter invalid tool: " + repr(tool)) return valid_list def _reset_app(self) -> App: - tool_config = self._read_json() - app_kwargs = self._build_tool_kwargs(tool_config.get("kwargs", {})) + self.tool_config = self._read_json() + self.app_kwargs = self._build_tool_kwargs(self.tool_config.get("kwargs", {})) app = AppFactory() - app.init_env(**app_kwargs) - + app.init_env(**self.app_kwargs) # filter not support tool - tool_list = self._filter_tool_list(tool_config.get("tools", [])) + tool_list = self._filter_tool_list(self.tool_config.get("tools", [])) - return app.create_app(tools_list=tool_list, **app_kwargs) + return app.create_app(tools_list=tool_list, **self.app_kwargs) diff --git a/requirements-optional.txt b/requirements-optional.txt index 563327428..f158d330f 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -7,26 +7,40 @@ 
gTTS>=2.3.1 # google text to speech pyttsx3>=2.90 # pytsx text to speech baidu_aip>=4.16.10 # baidu voice azure-cognitiveservices-speech # azure voice +edge-tts # edge-tts numpy<=1.24.2 langid # language detect +elevenlabs==1.0.3 # elevenlabs TTS #install plugin dulwich -# wechaty -wechaty>=0.10.7 -wechaty_puppet>=0.4.23 -# pysilk_mod>=1.6.0 # needed by send voice only in wechaty - -# wechatmp wechatcom +# wechatmp && wechatcom web.py wechatpy # chatgpt-tool-hub plugin -chatgpt_tool_hub==0.4.6 +chatgpt_tool_hub==0.5.0 # xunfei spark websocket-client==1.2.0 # claude bot curl_cffi +# claude API +anthropic + +# tongyi qwen +broadscope_bailian + +# google +google-generativeai + +# dingtalk +dingtalk_stream + +# zhipuai +zhipuai>=2.0.1 + +# tongyi qwen new sdk +dashscope diff --git a/requirements.txt b/requirements.txt index c032e0886..917f36e06 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ chardet>=5.1.0 Pillow pre-commit web.py +linkai>=0.0.6.0 diff --git a/voice/ali/ali_api.py b/voice/ali/ali_api.py new file mode 100644 index 000000000..def5c7add --- /dev/null +++ b/voice/ali/ali_api.py @@ -0,0 +1,216 @@ +# coding=utf-8 +""" +Author: chazzjimel +Email: chazzjimel@gmail.com +wechat:cheung-z-x + +Description: + +""" + +import http.client +import json +import time +import requests +import datetime +import hashlib +import hmac +import base64 +import urllib.parse +import uuid + +from common.log import logger +from common.tmp_dir import TmpDir + + +def text_to_speech_aliyun(url, text, appkey, token): + """ + 使用阿里云的文本转语音服务将文本转换为语音。 + + 参数: + - url (str): 阿里云文本转语音服务的端点URL。 + - text (str): 要转换为语音的文本。 + - appkey (str): 您的阿里云appkey。 + - token (str): 阿里云API的认证令牌。 + + 返回值: + - str: 成功时输出音频文件的路径,否则为None。 + """ + headers = { + "Content-Type": "application/json", + } + + data = { + "text": text, + "appkey": appkey, + "token": token, + "format": "wav" + } + + response = requests.post(url, headers=headers, data=json.dumps(data)) + + if response.status_code == 
200 and response.headers['Content-Type'] == 'audio/mpeg': + output_file = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav" + + with open(output_file, 'wb') as file: + file.write(response.content) + logger.debug(f"音频文件保存成功,文件名:{output_file}") + else: + logger.debug("响应状态码: {}".format(response.status_code)) + logger.debug("响应内容: {}".format(response.text)) + output_file = None + + return output_file + +def speech_to_text_aliyun(url, audioContent, appkey, token): + """ + 使用阿里云的语音识别服务识别音频文件中的语音。 + + 参数: + - url (str): 阿里云语音识别服务的端点URL。 + - audioContent (byte): pcm音频数据。 + - appkey (str): 您的阿里云appkey。 + - token (str): 阿里云API的认证令牌。 + + 返回值: + - str: 成功时输出识别到的文本,否则为None。 + """ + format = 'pcm' + sample_rate = 16000 + enablePunctuationPrediction = True + enableInverseTextNormalization = True + enableVoiceDetection = False + + # 设置RESTful请求参数 + request = url + '?appkey=' + appkey + request = request + '&format=' + format + request = request + '&sample_rate=' + str(sample_rate) + + if enablePunctuationPrediction : + request = request + '&enable_punctuation_prediction=' + 'true' + + if enableInverseTextNormalization : + request = request + '&enable_inverse_text_normalization=' + 'true' + + if enableVoiceDetection : + request = request + '&enable_voice_detection=' + 'true' + + host = 'nls-gateway-cn-shanghai.aliyuncs.com' + + # 设置HTTPS请求头部 + httpHeaders = { + 'X-NLS-Token': token, + 'Content-type': 'application/octet-stream', + 'Content-Length': len(audioContent) + } + + conn = http.client.HTTPSConnection(host) + conn.request(method='POST', url=request, body=audioContent, headers=httpHeaders) + + response = conn.getresponse() + body = response.read() + try: + body = json.loads(body) + status = body['status'] + if status == 20000000 : + result = body['result'] + if result : + logger.info(f"阿里云语音识别到了:{result}") + conn.close() + return result + else : + logger.error(f"语音识别失败,状态码: {status}") + except ValueError: + 
logger.error(f"语音识别失败,收到非JSON格式的数据: {body}") + conn.close() + return None + + +class AliyunTokenGenerator: + """ + 用于生成阿里云服务认证令牌的类。 + + 属性: + - access_key_id (str): 您的阿里云访问密钥ID。 + - access_key_secret (str): 您的阿里云访问密钥秘密。 + """ + + def __init__(self, access_key_id, access_key_secret): + self.access_key_id = access_key_id + self.access_key_secret = access_key_secret + + def sign_request(self, parameters): + """ + 为阿里云服务签名请求。 + + 参数: + - parameters (dict): 请求的参数字典。 + + 返回值: + - str: 请求的签名签章。 + """ + # 将参数按照字典顺序排序 + sorted_params = sorted(parameters.items()) + + # 构造待签名的查询字符串 + canonicalized_query_string = '' + for (k, v) in sorted_params: + canonicalized_query_string += '&' + self.percent_encode(k) + '=' + self.percent_encode(v) + + # 构造用于签名的字符串 + string_to_sign = 'GET&%2F&' + self.percent_encode(canonicalized_query_string[1:]) # 使用GET方法 + + # 使用HMAC算法计算签名 + h = hmac.new((self.access_key_secret + "&").encode('utf-8'), string_to_sign.encode('utf-8'), hashlib.sha1) + signature = base64.encodebytes(h.digest()).strip() + + return signature + + def percent_encode(self, encode_str): + """ + 对字符串进行百分比编码。 + + 参数: + - encode_str (str): 要编码的字符串。 + + 返回值: + - str: 编码后的字符串。 + """ + encode_str = str(encode_str) + res = urllib.parse.quote(encode_str, '') + res = res.replace('+', '%20') + res = res.replace('*', '%2A') + res = res.replace('%7E', '~') + return res + + def get_token(self): + """ + 获取阿里云服务的令牌。 + + 返回值: + - str: 获取到的令牌。 + """ + # 设置请求参数 + params = { + 'Format': 'JSON', + 'Version': '2019-02-28', + 'AccessKeyId': self.access_key_id, + 'SignatureMethod': 'HMAC-SHA1', + 'Timestamp': datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"), + 'SignatureVersion': '1.0', + 'SignatureNonce': str(uuid.uuid4()), # 使用uuid生成唯一的随机数 + 'Action': 'CreateToken', + 'RegionId': 'cn-shanghai' + } + + # 计算签名 + signature = self.sign_request(params) + params['Signature'] = signature + + # 构造请求URL + url = 'http://nls-meta.cn-shanghai.aliyuncs.com/?' 
+ urllib.parse.urlencode(params) + + # 发送请求 + response = requests.get(url) + + return response.text diff --git a/voice/ali/ali_voice.py b/voice/ali/ali_voice.py new file mode 100644 index 000000000..43ea0b46f --- /dev/null +++ b/voice/ali/ali_voice.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +""" +Author: chazzjimel +Email: chazzjimel@gmail.com +wechat:cheung-z-x + +Description: +ali voice service + +""" +import json +import os +import re +import time + +from bridge.reply import Reply, ReplyType +from common.log import logger +from voice.audio_convert import get_pcm_from_wav +from voice.voice import Voice +from voice.ali.ali_api import AliyunTokenGenerator, speech_to_text_aliyun, text_to_speech_aliyun +from config import conf + + +class AliVoice(Voice): + def __init__(self): + """ + 初始化AliVoice类,从配置文件加载必要的配置。 + """ + try: + curdir = os.path.dirname(__file__) + config_path = os.path.join(curdir, "config.json") + with open(config_path, "r") as fr: + config = json.load(fr) + self.token = None + self.token_expire_time = 0 + # 默认复用阿里云千问的 access_key 和 access_secret + self.api_url_voice_to_text = config.get("api_url_voice_to_text") + self.api_url_text_to_voice = config.get("api_url_text_to_voice") + self.app_key = config.get("app_key") + self.access_key_id = conf().get("qwen_access_key_id") or config.get("access_key_id") + self.access_key_secret = conf().get("qwen_access_key_secret") or config.get("access_key_secret") + except Exception as e: + logger.warn("AliVoice init failed: %s, ignore " % e) + + def textToVoice(self, text): + """ + 将文本转换为语音文件。 + + :param text: 要转换的文本。 + :return: 返回一个Reply对象,其中包含转换得到的语音文件或错误信息。 + """ + # 清除文本中的非中文、非英文和非基本字符 + text = re.sub(r'[^\u4e00-\u9fa5\u3040-\u30FF\uAC00-\uD7AFa-zA-Z0-9' + r'äöüÄÖÜáéíóúÁÉÍÓÚàèìòùÀÈÌÒÙâêîôûÂÊÎÔÛçÇñÑ,。!?,.]', '', text) + # 提取有效的token + token_id = self.get_valid_token() + fileName = text_to_speech_aliyun(self.api_url_text_to_voice, text, self.app_key, token_id) + if fileName: + logger.info("[Ali] textToVoice 
text={} voice file name={}".format(text, fileName)) + reply = Reply(ReplyType.VOICE, fileName) + else: + reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") + return reply + + def voiceToText(self, voice_file): + """ + 将语音文件转换为文本。 + + :param voice_file: 要转换的语音文件。 + :return: 返回一个Reply对象,其中包含转换得到的文本或错误信息。 + """ + # 提取有效的token + token_id = self.get_valid_token() + logger.debug("[Ali] voice file name={}".format(voice_file)) + pcm = get_pcm_from_wav(voice_file) + text = speech_to_text_aliyun(self.api_url_voice_to_text, pcm, self.app_key, token_id) + if text: + logger.info("[Ali] VoicetoText = {}".format(text)) + reply = Reply(ReplyType.TEXT, text) + else: + reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") + return reply + + def get_valid_token(self): + """ + 获取有效的阿里云token。 + + :return: 返回有效的token字符串。 + """ + current_time = time.time() + if self.token is None or current_time >= self.token_expire_time: + get_token = AliyunTokenGenerator(self.access_key_id, self.access_key_secret) + token_str = get_token.get_token() + token_data = json.loads(token_str) + self.token = token_data["Token"]["Id"] + # 将过期时间减少一小段时间(例如5分钟),以避免在边界条件下的过期 + self.token_expire_time = token_data["Token"]["ExpireTime"] - 300 + logger.debug(f"新获取的阿里云token:{self.token}") + else: + logger.debug("使用缓存的token") + return self.token diff --git a/voice/ali/config.json.template b/voice/ali/config.json.template new file mode 100644 index 000000000..563c57f0e --- /dev/null +++ b/voice/ali/config.json.template @@ -0,0 +1,7 @@ +{ + "api_url_text_to_voice": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts", + "api_url_voice_to_text": "https://nls-gateway.cn-shanghai.aliyuncs.com/stream/v1/asr", + "app_key": "", + "access_key_id": "", + "access_key_secret": "" +} \ No newline at end of file diff --git a/voice/audio_convert.py b/voice/audio_convert.py index 18fe3c2f3..426367883 100644 --- a/voice/audio_convert.py +++ b/voice/audio_convert.py @@ -6,7 +6,7 @@ try: import pysilk except ImportError: - logger.warn("import 
pysilk failed, wechaty voice message will not be supported.") + logger.debug("import pysilk failed, wechaty voice message will not be supported.") from pydub import AudioSegment @@ -64,7 +64,9 @@ def any_to_wav(any_path, wav_path): if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): return sil_to_wav(any_path, wav_path) audio = AudioSegment.from_file(any_path) - audio.export(wav_path, format="wav") + audio = audio.set_frame_rate(8000) # set_frame_rate返回新的AudioSegment,需重新赋值;百度语音转写支持8000采样率, pcm_s16le, 单通道语音识别 + audio = audio.set_channels(1) + audio.export(wav_path, format="wav", codec='pcm_s16le') def any_to_sil(any_path, sil_path): diff --git a/voice/azure/azure_voice.py b/voice/azure/azure_voice.py index b5884ed4f..b0ec8b81b 100644 --- a/voice/azure/azure_voice.py +++ b/voice/azure/azure_voice.py @@ -65,7 +65,7 @@ def voiceToText(self, voice_file): reply = Reply(ReplyType.TEXT, result.text) else: cancel_details = result.cancellation_details - logger.error("[Azure] voiceToText error, result={}, errordetails={}".format(result, cancel_details.error_details)) + logger.error("[Azure] voiceToText error, result={}, errordetails={}".format(result, cancel_details)) reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") return reply diff --git a/voice/edge/edge_voice.py b/voice/edge/edge_voice.py new file mode 100644 index 000000000..7bb8b2e6a --- /dev/null +++ b/voice/edge/edge_voice.py @@ -0,0 +1,50 @@ +import time + +import edge_tts +import asyncio + +from bridge.reply import Reply, ReplyType +from common.log import logger +from common.tmp_dir import TmpDir +from voice.voice import Voice + + +class EdgeVoice(Voice): + + def __init__(self): + ''' + # 普通话 + zh-CN-XiaoxiaoNeural + zh-CN-XiaoyiNeural + zh-CN-YunjianNeural + zh-CN-YunxiNeural + zh-CN-YunxiaNeural + zh-CN-YunyangNeural + # 地方口音 + zh-CN-liaoning-XiaobeiNeural + zh-CN-shaanxi-XiaoniNeural + # 粤语 + zh-HK-HiuGaaiNeural + zh-HK-HiuMaanNeural + zh-HK-WanLungNeural + # 湾湾腔 + zh-TW-HsiaoChenNeural + zh-TW-HsiaoYuNeural + zh-TW-YunJheNeural
''' + self.voice = "zh-CN-YunjianNeural" + + def voiceToText(self, voice_file): + pass + + async def gen_voice(self, text, fileName): + communicate = edge_tts.Communicate(text, self.voice) + await communicate.save(fileName) + + def textToVoice(self, text): + fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3" + + asyncio.run(self.gen_voice(text, fileName)) + + logger.info("[EdgeTTS] textToVoice text={} voice file name={}".format(text, fileName)) + return Reply(ReplyType.VOICE, fileName) diff --git a/voice/elevent/elevent_voice.py b/voice/elevent/elevent_voice.py index 15936ab96..2cfa5a3fc 100644 --- a/voice/elevent/elevent_voice.py +++ b/voice/elevent/elevent_voice.py @@ -1,7 +1,7 @@ import time -from elevenlabs import set_api_key,generate - +from elevenlabs.client import ElevenLabs +from elevenlabs import save from bridge.reply import Reply, ReplyType from common.log import logger from common.tmp_dir import TmpDir @@ -9,7 +9,7 @@ from config import conf XI_API_KEY = conf().get("xi_api_key") -set_api_key(XI_API_KEY) +client = ElevenLabs(api_key=XI_API_KEY) name = conf().get("xi_voice_id") class ElevenLabsVoice(Voice): @@ -21,13 +21,12 @@ def voiceToText(self, voice_file): pass def textToVoice(self, text): - audio = generate( + audio = client.generate( text=text, voice=name, - model='eleven_multilingual_v1' + model='eleven_multilingual_v2' ) fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3" - with open(fileName, "wb") as f: - f.write(audio) + save(audio, fileName) logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName)) return Reply(ReplyType.VOICE, fileName) \ No newline at end of file diff --git a/voice/factory.py b/voice/factory.py index 8725e29d5..fa8b79e50 100644 --- a/voice/factory.py +++ b/voice/factory.py @@ -36,5 +36,18 @@ def create_voice(voice_type): elif voice_type == "linkai": from voice.linkai.linkai_voice 
import LinkAIVoice + return LinkAIVoice() + elif voice_type == "ali": + from voice.ali.ali_voice import AliVoice + + return AliVoice() + elif voice_type == "edge": + from voice.edge.edge_voice import EdgeVoice + + return EdgeVoice() + elif voice_type == "xunfei": + from voice.xunfei.xunfei_voice import XunfeiVoice + + return XunfeiVoice() raise RuntimeError diff --git a/voice/linkai/linkai_voice.py b/voice/linkai/linkai_voice.py index 7dc420c73..739b5f605 100644 --- a/voice/linkai/linkai_voice.py +++ b/voice/linkai/linkai_voice.py @@ -19,15 +19,18 @@ def __init__(self): def voiceToText(self, voice_file): logger.debug("[LinkVoice] voice file name={}".format(voice_file)) try: - url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/transcriptions" + url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/audio/transcriptions" headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} model = None if not conf().get("text_to_voice") or conf().get("voice_to_text") == "openai": model = const.WHISPER_1 if voice_file.endswith(".amr"): - mp3_file = os.path.splitext(voice_file)[0] + ".mp3" - audio_convert.any_to_mp3(voice_file, mp3_file) - voice_file = mp3_file + try: + mp3_file = os.path.splitext(voice_file)[0] + ".mp3" + audio_convert.any_to_mp3(voice_file, mp3_file) + voice_file = mp3_file + except Exception as e: + logger.warn(f"[LinkVoice] amr file transfer failed, directly send amr voice file: {format(e)}") file = open(voice_file, "rb") file_body = { "file": file @@ -46,12 +49,12 @@ def voiceToText(self, voice_file): logger.info(f"[LinkVoice] voiceToText success, text={text}, file name={voice_file}") except Exception as e: logger.error(e) - reply = Reply(ReplyType.ERROR, "我暂时还无法听清您的语音,请稍后再试吧~") + return None return reply def textToVoice(self, text): try: - url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/speech" + url = conf().get("linkai_api_base", "https://api.link-ai.tech") + 
"/v1/audio/speech" headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} model = const.TTS_1 if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: @@ -59,7 +62,8 @@ def textToVoice(self, text): data = { "model": model, "input": text, - "voice": conf().get("tts_voice_id") + "voice": conf().get("tts_voice_id"), + "app_code": conf().get("linkai_app_code") } res = requests.post(url, headers=headers, json=data, timeout=(5, 120)) if res.status_code == 200: @@ -75,5 +79,5 @@ def textToVoice(self, text): return None except Exception as e: logger.error(e) - reply = Reply(ReplyType.ERROR, "遇到了一点小问题,请稍后再问我吧") - return reply + # reply = Reply(ReplyType.ERROR, "遇到了一点小问题,请稍后再问我吧") + return None diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index 2dd3cbefb..506d8b5c7 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -21,8 +21,21 @@ def voiceToText(self, voice_file): logger.debug("[Openai] voice file name={}".format(voice_file)) try: file = open(voice_file, "rb") - result = openai.Audio.transcribe("whisper-1", file) - text = result["text"] + api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1" + url = f'{api_base}/audio/transcriptions' + headers = { + 'Authorization': 'Bearer ' + conf().get("open_ai_api_key"), + # 'Content-Type': 'multipart/form-data' # 加了会报错,不知道什么原因 + } + files = { + "file": file, + } + data = { + "model": "whisper-1", + } + response = requests.post(url, headers=headers, files=files, data=data) + response_data = response.json() + text = response_data['text'] reply = Reply(ReplyType.TEXT, text) logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file)) except Exception as e: @@ -33,7 +46,8 @@ def voiceToText(self, voice_file): def textToVoice(self, text): try: - url = 'https://api.openai.com/v1/audio/speech' + api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1" + url = 
f'{api_base}/audio/speech' headers = { 'Authorization': 'Bearer ' + conf().get("open_ai_api_key"), 'Content-Type': 'application/json' diff --git a/voice/xunfei/config.json.template b/voice/xunfei/config.json.template new file mode 100644 index 000000000..770160dfc --- /dev/null +++ b/voice/xunfei/config.json.template @@ -0,0 +1,7 @@ +{ + "APPID":"xxx71xxx", + "APIKey":"xxxx69058exxxxxx", + "APISecret":"xxxx697f0xxxxxx", + "BusinessArgsTTS":{"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}, + "BusinessArgsASR":{"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos":10000, "dwa": "wpgs"} +} diff --git a/voice/xunfei/xunfei_asr.py b/voice/xunfei/xunfei_asr.py new file mode 100644 index 000000000..2e81369f1 --- /dev/null +++ b/voice/xunfei/xunfei_asr.py @@ -0,0 +1,209 @@ +# -*- coding:utf-8 -*- +# +# Author: njnuko +# Email: njnuko@163.com +# +# 这个文档是基于官方的demo来改的,具体官方demo文档请参考官网 +# +# 语音听写流式 WebAPI 接口调用示例 接口文档(必看):https://doc.xfyun.cn/rest_api/语音听写(流式版).html +# webapi 听写服务参考帖子(必看):http://bbs.xfyun.cn/forum.php?mod=viewthread&tid=38947&extra= +# 语音听写流式WebAPI 服务,热词使用方式:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--个性化热词, +# 设置热词 +# 注意:热词只是在识别时增加相应词条的识别权重,并不保证绝对识别,具体效果以您测试为准。 +# 语音听写流式WebAPI 服务,方言试用方法:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--识别语种列表 +# 可添加语种或方言,添加后会显示该方言的参数值 +# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # + +import websocket +import datetime +import hashlib +import base64 +import hmac +import json +from urllib.parse import urlencode +import time +import ssl +from wsgiref.handlers import format_date_time +from datetime import datetime +from time import mktime +import _thread as thread +import os +import wave + + +STATUS_FIRST_FRAME = 0 # 第一帧的标识 +STATUS_CONTINUE_FRAME = 1 # 中间帧标识 +STATUS_LAST_FRAME = 2 # 最后一帧的标识 + 
+############# +#whole_dict 是用来存储返回值的,由于带语音修正,所以用dict来存储,有更新的话pop之前的值,最后再合并 +global whole_dict +#这个文档是官方文档改的,这个参数是用来做函数调用时用的 +global wsParam +############## + + +class Ws_Param(object): + # 初始化 + def __init__(self, APPID, APIKey, APISecret,BusinessArgs, AudioFile): + self.APPID = APPID + self.APIKey = APIKey + self.APISecret = APISecret + self.AudioFile = AudioFile + self.BusinessArgs = BusinessArgs + # 公共参数(common) + self.CommonArgs = {"app_id": self.APPID} + # 业务参数(business),更多个性化参数可在官网查看 + #self.BusinessArgs = {"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vinfo":1,"vad_eos":10000} + + # 生成url + def create_url(self): + url = 'wss://ws-api.xfyun.cn/v2/iat' + # 生成RFC1123格式的时间戳 + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + # 拼接字符串 + signature_origin = "host: " + "ws-api.xfyun.cn" + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + "/v2/iat " + "HTTP/1.1" + # 使用hmac-sha256进行加密 + signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( + self.APIKey, "hmac-sha256", "host date request-line", signature_sha) + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": "ws-api.xfyun.cn" + } + # 拼接鉴权参数,生成url + url = url + '?'
+ urlencode(v) + #print("date: ",date) + #print("v: ",v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + #print('websocket url :', url) + return url + + +# 收到websocket消息的处理 +def on_message(ws, message): + global whole_dict + try: + code = json.loads(message)["code"] + sid = json.loads(message)["sid"] + if code != 0: + errMsg = json.loads(message)["message"] + print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) + else: + temp1 = json.loads(message)["data"]["result"] + data = json.loads(message)["data"]["result"]["ws"] + sn = temp1["sn"] + if "rg" in temp1.keys(): + rep = temp1["rg"] + rep_start = rep[0] + rep_end = rep[1] + for sn in range(rep_start,rep_end+1): + #print("before pop",whole_dict) + #print("sn",sn) + whole_dict.pop(sn,None) + #print("after pop",whole_dict) + results = "" + for i in data: + for w in i["cw"]: + results += w["w"] + whole_dict[sn]=results + #print("after add",whole_dict) + else: + results = "" + for i in data: + for w in i["cw"]: + results += w["w"] + whole_dict[sn]=results + #print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False))) + except Exception as e: + print("receive msg,but parse exception:", e) + + + +# 收到websocket错误的处理 +def on_error(ws, error): + print("### error:", error) + + +# 收到websocket关闭的处理 +def on_close(ws,a,b): + print("### closed ###") + + +# 收到websocket连接建立的处理 +def on_open(ws): + global wsParam + def run(*args): + frameSize = 8000 # 每一帧的音频大小 + intervel = 0.04 # 发送音频间隔(单位:s) + status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧 + + with wave.open(wsParam.AudioFile, "rb") as fp: + while True: + buf = fp.readframes(frameSize) + # 文件结束 + if not buf: + status = STATUS_LAST_FRAME + # 第一帧处理 + # 发送第一帧音频,带business 参数 + # appid 必须带上,只需第一帧发送 + if status == STATUS_FIRST_FRAME: + d = {"common": wsParam.CommonArgs, + "business": wsParam.BusinessArgs, + "data": {"status": 0, "format": "audio/L16;rate=16000","audio": str(base64.b64encode(buf), 'utf-8'), "encoding": 
"raw"}} + d = json.dumps(d) + ws.send(d) + status = STATUS_CONTINUE_FRAME + # 中间帧处理 + elif status == STATUS_CONTINUE_FRAME: + d = {"data": {"status": 1, "format": "audio/L16;rate=16000", + "audio": str(base64.b64encode(buf), 'utf-8'), + "encoding": "raw"}} + ws.send(json.dumps(d)) + # 最后一帧处理 + elif status == STATUS_LAST_FRAME: + d = {"data": {"status": 2, "format": "audio/L16;rate=16000", + "audio": str(base64.b64encode(buf), 'utf-8'), + "encoding": "raw"}} + ws.send(json.dumps(d)) + time.sleep(1) + break + # 模拟音频采样间隔 + time.sleep(intervel) + ws.close() + + thread.start_new_thread(run, ()) + +#提供给xunfei_voice调用的函数 +def xunfei_asr(APPID,APISecret,APIKey,BusinessArgsASR,AudioFile): + global whole_dict + global wsParam + whole_dict = {} + wsParam1 = Ws_Param(APPID=APPID, APISecret=APISecret, + APIKey=APIKey,BusinessArgs=BusinessArgsASR, + AudioFile=AudioFile) + #wsParam是global变量,给上面on_open函数调用使用的 + wsParam = wsParam1 + websocket.enableTrace(False) + wsUrl = wsParam.create_url() + ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close) + ws.on_open = on_open + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) + #把字典的值合并起来做最后识别的输出 + whole_words = "" + for i in sorted(whole_dict.keys()): + whole_words += whole_dict[i] + return whole_words + + diff --git a/voice/xunfei/xunfei_tts.py b/voice/xunfei/xunfei_tts.py new file mode 100644 index 000000000..33f934cfc --- /dev/null +++ b/voice/xunfei/xunfei_tts.py @@ -0,0 +1,163 @@ +# -*- coding:utf-8 -*- +# +# Author: njnuko +# Email: njnuko@163.com +# +# 这个文档是基于官方的demo来改的,固体官方demo文档请参考官网 +# +# 语音听写流式 WebAPI 接口调用示例 接口文档(必看):https://doc.xfyun.cn/rest_api/语音听写(流式版).html +# webapi 听写服务参考帖子(必看):http://bbs.xfyun.cn/forum.php?mod=viewthread&tid=38947&extra= +# 语音听写流式WebAPI 服务,热词使用方式:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--个性化热词, +# 设置热词 +# 注意:热词只能在识别的时候会增加热词的识别权重,需要注意的是增加相应词条的识别率,但并不是绝对的,具体效果以您测试为准。 +# 语音听写流式WebAPI 
服务,方言试用方法:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--识别语种列表 +# 可添加语种或方言,添加后会显示该方言的参数值 +# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +import websocket +import datetime +import hashlib +import base64 +import hmac +import json +from urllib.parse import urlencode +import time +import ssl +from wsgiref.handlers import format_date_time +from datetime import datetime +from time import mktime +import _thread as thread +import os + + + +STATUS_FIRST_FRAME = 0 # 第一帧的标识 +STATUS_CONTINUE_FRAME = 1 # 中间帧标识 +STATUS_LAST_FRAME = 2 # 最后一帧的标识 + +############# +#这个参数是用来做输出文件路径的 +global outfile +#这个文档是官方文档改的,这个参数是用来做函数调用时用的 +global wsParam +############## + + +class Ws_Param(object): + # 初始化 + def __init__(self, APPID, APIKey, APISecret,BusinessArgs,Text): + self.APPID = APPID + self.APIKey = APIKey + self.APISecret = APISecret + self.BusinessArgs = BusinessArgs + self.Text = Text + + # 公共参数(common) + self.CommonArgs = {"app_id": self.APPID} + # 业务参数(business),更多个性化参数可在官网查看 + #self.BusinessArgs = {"aue": "raw", "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"} + self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-8')), "UTF8")} + #使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"” + #self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")} + + # 生成url + def create_url(self): + url = 'wss://tts-api.xfyun.cn/v2/tts' + # 生成RFC1123格式的时间戳 + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + # 拼接字符串 + signature_origin = "host: " + "ws-api.xfyun.cn" + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + "/v2/tts " + "HTTP/1.1" + # 进行hmac-sha256进行加密 + signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + signature_sha = 
base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( + self.APIKey, "hmac-sha256", "host date request-line", signature_sha) + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": "ws-api.xfyun.cn" + } + # 拼接鉴权参数,生成url + url = url + '?' + urlencode(v) + # print("date: ",date) + # print("v: ",v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + # print('websocket url :', url) + return url + +def on_message(ws, message): + #输出文件 + global outfile + try: + message =json.loads(message) + code = message["code"] + sid = message["sid"] + audio = message["data"]["audio"] + audio = base64.b64decode(audio) + status = message["data"]["status"] + if status == 2: + print("ws is closed") + ws.close() + if code != 0: + errMsg = message["message"] + print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) + else: + + with open(outfile, 'ab') as f: + f.write(audio) + + except Exception as e: + print("receive msg,but parse exception:", e) + + + +# 收到websocket连接建立的处理 +def on_open(ws): + global outfile + global wsParam + def run(*args): + d = {"common": wsParam.CommonArgs, + "business": wsParam.BusinessArgs, + "data": wsParam.Data, + } + d = json.dumps(d) + # print("------>开始发送文本数据") + ws.send(d) + if os.path.exists(outfile): + os.remove(outfile) + + thread.start_new_thread(run, ()) + +# 收到websocket错误的处理 +def on_error(ws, error): + print("### error:", error) + + + +# 收到websocket关闭的处理 +def on_close(ws): + print("### closed ###") + + + +def xunfei_tts(APPID, APIKey, APISecret,BusinessArgsTTS, Text, OutFile): + global outfile + global wsParam + outfile = OutFile + wsParam1 = Ws_Param(APPID,APIKey,APISecret,BusinessArgsTTS,Text) + wsParam = wsParam1 + websocket.enableTrace(False) + wsUrl = wsParam.create_url() + ws = 
websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close) + ws.on_open = on_open + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) + return outfile + diff --git a/voice/xunfei/xunfei_voice.py b/voice/xunfei/xunfei_voice.py new file mode 100644 index 000000000..7b503f7a6 --- /dev/null +++ b/voice/xunfei/xunfei_voice.py @@ -0,0 +1,86 @@ +##################################################################### +# xunfei voice service +# Auth: njnuko +# Email: njnuko@163.com +# +# 要使用本模块, 首先到 xfyun.cn 注册一个开发者账号, +# 之后创建一个新应用, 然后在应用管理的语音识别或者语音合同右边可以查看APPID API Key 和 Secret Key +# 然后在 config.json 中填入这三个值 +# +# 配置说明: +# { +# "APPID":"xxx71xxx", +# "APIKey":"xxxx69058exxxxxx", #讯飞xfyun.cn控制台语音合成或者听写界面的APIKey +# "APISecret":"xxxx697f0xxxxxx", #讯飞xfyun.cn控制台语音合成或者听写界面的APIKey +# "BusinessArgsTTS":{"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}, #语音合成的参数,具体可以参考xfyun.cn的文档 +# "BusinessArgsASR":{"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos":10000, "dwa": "wpgs"} #语音听写的参数,具体可以参考xfyun.cn的文档 +# } +##################################################################### + +import json +import os +import time + +from bridge.reply import Reply, ReplyType +from common.log import logger +from common.tmp_dir import TmpDir +from config import conf +from voice.voice import Voice +from .xunfei_asr import xunfei_asr +from .xunfei_tts import xunfei_tts +from voice.audio_convert import any_to_mp3 +import shutil +from pydub import AudioSegment + + +class XunfeiVoice(Voice): + def __init__(self): + try: + curdir = os.path.dirname(__file__) + config_path = os.path.join(curdir, "config.json") + conf = None + with open(config_path, "r") as fr: + conf = json.load(fr) + print(conf) + self.APPID = str(conf.get("APPID")) + self.APIKey = str(conf.get("APIKey")) + self.APISecret = str(conf.get("APISecret")) + self.BusinessArgsTTS = conf.get("BusinessArgsTTS") + self.BusinessArgsASR= 
conf.get("BusinessArgsASR") + + except Exception as e: + logger.warn("XunfeiVoice init failed: %s, ignore " % e) + + def voiceToText(self, voice_file): + # 识别本地文件 + try: + logger.debug("[Xunfei] voice file name={}".format(voice_file)) + #print("voice_file===========",voice_file) + #print("voice_file_type===========",type(voice_file)) + #mp3_name, file_extension = os.path.splitext(voice_file) + #mp3_file = mp3_name + ".mp3" + #pcm_data=get_pcm_from_wav(voice_file) + #mp3_name, file_extension = os.path.splitext(voice_file) + #AudioSegment.from_wav(voice_file).export(mp3_file, format="mp3") + #shutil.copy2(voice_file, 'tmp/test1.wav') + #shutil.copy2(mp3_file, 'tmp/test1.mp3') + #print("voice and mp3 file",voice_file,mp3_file) + text = xunfei_asr(self.APPID,self.APISecret,self.APIKey,self.BusinessArgsASR,voice_file) + logger.info("讯飞语音识别到了: {}".format(text)) + reply = Reply(ReplyType.TEXT, text) + except Exception as e: + logger.warn("XunfeiVoice init failed: %s, ignore " % e) + reply = Reply(ReplyType.ERROR, "讯飞语音识别出错了;{0}") + return reply + + def textToVoice(self, text): + try: + # Avoid the same filename under multithreading + fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3" + return_file = xunfei_tts(self.APPID,self.APIKey,self.APISecret,self.BusinessArgsTTS,text,fileName) + logger.info("[Xunfei] textToVoice text={} voice file name={}".format(text, fileName)) + reply = Reply(ReplyType.VOICE, fileName) + except Exception as e: + logger.error("[Xunfei] textToVoice error={}".format(fileName)) + reply = Reply(ReplyType.ERROR, "抱歉,讯飞语音合成失败") + return reply
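Both `Ws_Param.create_url` implementations in this PR use the same authentication scheme: sign the `host`, `date` and request-line headers with HMAC-SHA256, base64 the signature, wrap it in an `authorization` header string, and base64 that again into a query parameter. A minimal standalone sketch of the scheme for debugging handshake failures (the key values are placeholders and the helper name `build_auth_url` is ours, not part of the PR):

```python
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time


def build_auth_url(base_url, host, path, api_key, api_secret):
    # RFC1123 date, signed together with the host and the request-line
    date = format_date_time(mktime(datetime.now().timetuple()))
    signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    # HMAC-SHA256 over the canonical string, then base64
    digest = hmac.new(api_secret.encode("utf-8"),
                      signature_origin.encode("utf-8"),
                      digestmod=hashlib.sha256).digest()
    signature = base64.b64encode(digest).decode("utf-8")
    # the authorization header string is itself base64-encoded into the query
    authorization_origin = (f'api_key="{api_key}", algorithm="hmac-sha256", '
                            f'headers="host date request-line", signature="{signature}"')
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
    return base_url + "?" + urlencode({"authorization": authorization,
                                       "date": date,
                                       "host": host})


url = build_auth_url("wss://ws-api.xfyun.cn/v2/iat", "ws-api.xfyun.cn",
                     "/v2/iat", "demo_key", "demo_secret")
print(url)
```

Comparing the query string this produces against the one `create_url` generates for the same date is a quick way to localize 401 handshake errors to either the signing step or the parameter encoding.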