fallback model

This commit is contained in:
tkisme 2024-05-17 15:19:48 +08:00
parent 8351543f62
commit 811aad5bfc
1 changed files with 80 additions and 46 deletions

126
main.go
View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"github.com/gin-gonic/gin"
"github.com/linux-do/tiktoken-go"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/net/http2"
@ -14,10 +15,14 @@ import (
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
)
const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct"
type config struct {
Bind string `json:"bind"`
ProxyUrl string `json:"proxy_url"`
@ -26,7 +31,7 @@ type config struct {
CodexApiKey string `json:"codex_api_key"`
CodexApiOrganization string `json:"codex_api_organization"`
CodexApiProject string `json:"codex_api_project"`
CodexModelDefault string `json:"codex_model_default"`
CodexMaxTokens int `json:"codex_max_tokens"`
ChatApiBase string `json:"chat_api_base"`
ChatApiKey string `json:"chat_api_key"`
ChatApiOrganization string `json:"chat_api_organization"`
@ -47,35 +52,46 @@ func readConfig() *config {
log.Fatal(err)
}
v := reflect.ValueOf(_cfg).Elem()
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
tag := t.Field(i).Tag.Get("json")
if tag == "" {
continue
}
value, exists := os.LookupEnv("OVERRIDE_" + strings.ToUpper(tag))
if !exists {
continue
}
switch field.Kind() {
case reflect.String:
field.SetString(value)
case reflect.Bool:
if boolValue, err := strconv.ParseBool(value); err == nil {
field.SetBool(boolValue)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if intValue, err := strconv.ParseInt(value, 10, 64); err == nil {
field.SetInt(intValue)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if uintValue, err := strconv.ParseUint(value, 10, 64); err == nil {
field.SetUint(uintValue)
}
case reflect.Float32, reflect.Float64:
if floatValue, err := strconv.ParseFloat(value, field.Type().Bits()); err == nil {
field.SetFloat(floatValue)
}
}
}
return _cfg
}
type GPTMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type StreamResponse struct {
Response string `json:"response"`
}
type Message struct {
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
Name *string `json:"name,omitempty"`
}
type ChatCompletionsStreamResponseChoice struct {
Index int `json:"index"`
Delta Message `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
func getClient(cfg *config) (*http.Client, error) {
transport := &http.Transport{
ForceAttemptHTTP2: true,
@ -119,8 +135,9 @@ func closeIO(c io.Closer) {
}
type ProxyService struct {
cfg *config
client *http.Client
cfg *config
client *http.Client
tokenizer *tiktoken.Tiktoken
}
func NewProxyService(cfg *config) (*ProxyService, error) {
@ -129,9 +146,15 @@ func NewProxyService(cfg *config) (*ProxyService, error) {
return nil, err
}
tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL)
if nil != err {
return nil, err
}
return &ProxyService{
cfg: cfg,
client: client,
cfg: cfg,
client: client,
tokenizer: tokenizer,
}, nil
}
@ -219,24 +242,35 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
return
}
prompt := gjson.GetBytes(body, "prompt").String()
suffix := gjson.GetBytes(body, "suffix").String()
inputTokens := len(s.tokenizer.Encode(prompt, nil, nil))
suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil))
outputTokens := int(gjson.GetBytes(body, "max_tokens").Int())
totalTokens := inputTokens + suffixTokens + outputTokens
if totalTokens > s.cfg.CodexMaxTokens { // reduce
left, right := 0, len(prompt)
for left < right {
mid := (left + right) / 2
subPrompt := prompt[mid:]
subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil))
totalTokens = subInputTokens + suffixTokens + outputTokens
if totalTokens > s.cfg.CodexMaxTokens {
left = mid + 1
} else {
right = mid
}
}
body, _ = sjson.SetBytes(body, "prompt", prompt[left:])
}
body, _ = sjson.DeleteBytes(body, "extra")
body, _ = sjson.DeleteBytes(body, "nwo")
if s.cfg.CodexModelDefault == "" {
s.cfg.CodexModelDefault = "gpt-3.5-turbo-instruct"
}
body, _ = sjson.SetBytes(body, "model", s.cfg.CodexModelDefault)
body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL)
proxyUrl := s.cfg.CodexApiBase
if strings.HasPrefix(s.cfg.CodexModelDefault, "@") {
proxyUrl = s.cfg.CodexApiBase
message := gjson.GetBytes(body, "prompt").String()
body, _ = sjson.DeleteBytes(body, "prompt")
msg := make([]GPTMessage, 0)
msg = append(msg, GPTMessage{Role: "system", Content: "You are a helpful assistant"})
msg = append(msg, GPTMessage{Role: "user", Content: message})
body, _ = sjson.SetBytes(body, "messages", msg)
body, _ = sjson.DeleteBytes(body, "n")
}
proxyUrl := s.cfg.CodexApiBase + "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
if nil != err {
abortCodex(c, http.StatusInternalServerError)