[go: nahoru, domu]

Skip to content

Commit

Permalink
feat: refactor AwsClaude to Aws to support both llama3 and claude (#1601)
Browse files Browse the repository at this point in the history

* feat: refactor AwsClaude to Aws to support both llama3 and claude

* fix: aws llama3 ratio
  • Loading branch information
WqyJh committed Jul 6, 2024
1 parent e090e76 commit 720fe2d
Show file tree
Hide file tree
Showing 18 changed files with 595 additions and 88 deletions.
78 changes: 40 additions & 38 deletions relay/adaptor/aws/adapter.go → relay/adaptor/aws/adaptor.go
Original file line number Diff line number Diff line change
@@ -1,82 +1,84 @@
package aws

import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"
"errors"
"io"
"net/http"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)

// Compile-time assertion that *Adaptor satisfies the relay adaptor.Adaptor interface.
var _ adaptor.Adaptor = (*Adaptor)(nil)

type Adaptor struct {
meta *meta.Meta
awsClient *bedrockruntime.Client
awsAdapter utils.AwsAdapter

Meta *meta.Meta
AwsClient *bedrockruntime.Client
}

func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
a.awsClient = bedrockruntime.New(bedrockruntime.Options{
a.Meta = meta
a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
Region: meta.Config.Region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
})
}

// GetRequestURL always yields an empty URL and a nil error. This adaptor talks
// to Bedrock through the bedrockruntime SDK client built in Init rather than a
// hand-assembled HTTP request, so there is no endpoint URL to compute.
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	var endpoint string
	return endpoint, nil
}

// SetupRequestHeader is a no-op that leaves req untouched and returns nil.
// The SDK client issues its own requests, so no headers need to be set here.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	var noErr error
	return noErr
}

// ConvertRequest turns an OpenAI-style request into an Anthropic request via
// anthropic.ConvertRequest. It stores the requested model name under
// ctxkey.RequestModel and the converted request under ctxkey.ConvertedRequest
// in the gin context, then returns the converted request. relayMode is not used.
// A nil request yields an error.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}

	requestedModel := request.Model
	converted := anthropic.ConvertRequest(*request)

	c.Set(ctxkey.RequestModel, requestedModel)
	c.Set(ctxkey.ConvertedRequest, converted)
	return converted, nil
}

func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
adaptor := GetAdaptor(request.Model)
if adaptor == nil {
return nil, errors.New("adaptor not found")
}
return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return nil, nil
a.awsAdapter = adaptor
return adaptor.ConvertRequest(c, relayMode, request)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, a.awsClient)
} else {
err, usage = Handler(c, a.awsClient, meta.ActualModelName)
if a.awsAdapter == nil {
return nil, utils.WrapErr(errors.New("awsAdapter is nil"))
}
return
return a.awsAdapter.DoResponse(c, a.AwsClient, meta)
}

func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
for model := range adaptors {
models = append(models, model)
}
return
}

// GetChannelName returns the fixed channel identifier "aws".
func (a *Adaptor) GetChannelName() string {
	const channelName = "aws"
	return channelName
}

// GetRequestURL always yields an empty URL and a nil error. Bedrock calls go
// through the SDK client (see Init), so there is no endpoint URL to build.
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	var endpoint string
	return endpoint, nil
}

// SetupRequestHeader does nothing and returns nil. Requests are issued by the
// bedrockruntime SDK client, so req needs no extra headers.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	var noErr error
	return noErr
}

// ConvertImageRequest passes a non-nil image request through unchanged and
// rejects a nil one with an error.
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
	if request != nil {
		return request, nil
	}
	return nil, errors.New("request is nil")
}

// DoRequest performs no HTTP round trip: it ignores requestBody and returns a
// nil response with a nil error. The actual Bedrock invocation happens later in
// DoResponse through the SDK client.
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	var noResponse *http.Response
	return noResponse, nil
}
37 changes: 37 additions & 0 deletions relay/adaptor/aws/claude/adapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package aws

import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)

// Compile-time assertion that *Adaptor satisfies utils.AwsAdapter.
var _ utils.AwsAdapter = (*Adaptor)(nil)

// Adaptor is the utils.AwsAdapter implementation for Anthropic Claude models
// on AWS Bedrock. It carries no fields; all per-request data travels through
// the gin context.
type Adaptor struct{}

// ConvertRequest converts an OpenAI-style request into an Anthropic request
// using anthropic.ConvertRequest. The requested model name (ctxkey.RequestModel)
// and the converted request (ctxkey.ConvertedRequest) are stored in the gin
// context for the Bedrock handlers, and the converted request is returned.
// relayMode is unused. A nil request yields an error.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}

	requestedModel := request.Model
	converted := anthropic.ConvertRequest(*request)

	c.Set(ctxkey.RequestModel, requestedModel)
	c.Set(ctxkey.ConvertedRequest, converted)
	return converted, nil
}

// DoResponse calls Bedrock through awsCli. It dispatches to StreamHandler when
// meta.IsStream is set and to Handler (with meta.ActualModelName) otherwise,
// then returns the usage and error produced by that handler.
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if !meta.IsStream {
		err, usage = Handler(c, awsCli, meta.ActualModelName)
		return usage, err
	}

	err, usage = StreamHandler(c, awsCli)
	return usage, err
}
40 changes: 16 additions & 24 deletions relay/adaptor/aws/main.go → relay/adaptor/aws/claude/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"io"
"net/http"

Expand All @@ -17,23 +15,17 @@ import (
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)

func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
return &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}

// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var awsModelIDMap = map[string]string{
var AwsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
Expand All @@ -44,7 +36,7 @@ var awsModelIDMap = map[string]string{
}

func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
return awsModelID, nil
}

Expand All @@ -54,7 +46,7 @@ func awsModelID(requestModel string) (string, error) {
func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}

awsReq := &bedrockruntime.InvokeModelInput{
Expand All @@ -65,30 +57,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*

claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
return utils.WrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)
awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
return utils.WrapErr(errors.Wrap(err, "copy request")), nil
}

awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}

awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
}

claudeResponse := new(anthropic.Response)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
}

openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
Expand All @@ -108,7 +100,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}

awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
Expand All @@ -119,24 +111,24 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E

claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
if !ok {
return wrapErr(errors.New("request not found")), nil
return utils.WrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*anthropic.Request)

awsClaudeReq := &Request{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
return utils.WrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
}

awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
Expand Down
File renamed without changes.
37 changes: 37 additions & 0 deletions relay/adaptor/aws/llama3/adapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package aws

import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/songquanpeng/one-api/common/ctxkey"

"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)

// Compile-time assertion that *Adaptor satisfies utils.AwsAdapter.
var _ utils.AwsAdapter = (*Adaptor)(nil)

// Adaptor is the utils.AwsAdapter implementation for Llama 3 models on AWS
// Bedrock. It carries no fields; per-request data travels through the gin
// context.
type Adaptor struct{}

// ConvertRequest converts an OpenAI-style request with this package's
// ConvertRequest function. It stores the requested model name
// (ctxkey.RequestModel) and the converted request (ctxkey.ConvertedRequest)
// in the gin context, then returns the converted request. relayMode is unused.
// A nil request yields an error.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}

	requestedModel := request.Model
	converted := ConvertRequest(*request)

	c.Set(ctxkey.RequestModel, requestedModel)
	c.Set(ctxkey.ConvertedRequest, converted)
	return converted, nil
}

// DoResponse calls Bedrock through awsCli. It uses StreamHandler when
// meta.IsStream is set and Handler (with meta.ActualModelName) otherwise, and
// returns the usage and error that handler produced.
func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	if !meta.IsStream {
		err, usage = Handler(c, awsCli, meta.ActualModelName)
		return usage, err
	}

	err, usage = StreamHandler(c, awsCli)
	return usage, err
}
Loading

0 comments on commit 720fe2d

Please sign in to comment.