cloudflare 支持

This commit is contained in:
tkisme 2024-05-17 14:46:43 +08:00
parent c6f0d41f33
commit 900056622b
1 changed files with 5 additions and 80 deletions

85
main.go
View File

@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"github.com/gin-gonic/gin"
"github.com/linux-do/tiktoken-go"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/net/http2"
@ -15,14 +14,9 @@ import (
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
)
const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct"
type config struct {
Bind string `json:"bind"`
ProxyUrl string `json:"proxy_url"`
@ -31,7 +25,6 @@ type config struct {
CodexApiKey string `json:"codex_api_key"`
CodexApiOrganization string `json:"codex_api_organization"`
CodexApiProject string `json:"codex_api_project"`
CodexMaxTokens int `json:"codex_max_tokens"`
ChatApiBase string `json:"chat_api_base"`
ChatApiKey string `json:"chat_api_key"`
ChatApiOrganization string `json:"chat_api_organization"`
@ -52,43 +45,6 @@ func readConfig() *config {
log.Fatal(err)
}
v := reflect.ValueOf(_cfg).Elem()
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
tag := t.Field(i).Tag.Get("json")
if tag == "" {
continue
}
value, exists := os.LookupEnv("OVERRIDE_" + strings.ToUpper(tag))
if !exists {
continue
}
switch field.Kind() {
case reflect.String:
field.SetString(value)
case reflect.Bool:
if boolValue, err := strconv.ParseBool(value); err == nil {
field.SetBool(boolValue)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if intValue, err := strconv.ParseInt(value, 10, 64); err == nil {
field.SetInt(intValue)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if uintValue, err := strconv.ParseUint(value, 10, 64); err == nil {
field.SetUint(uintValue)
}
case reflect.Float32, reflect.Float64:
if floatValue, err := strconv.ParseFloat(value, field.Type().Bits()); err == nil {
field.SetFloat(floatValue)
}
}
}
return _cfg
}
@ -135,9 +91,8 @@ func closeIO(c io.Closer) {
}
type ProxyService struct {
cfg *config
client *http.Client
tokenizer *tiktoken.Tiktoken
cfg *config
client *http.Client
}
func NewProxyService(cfg *config) (*ProxyService, error) {
@ -146,15 +101,9 @@ func NewProxyService(cfg *config) (*ProxyService, error) {
return nil, err
}
tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL)
if nil != err {
return nil, err
}
return &ProxyService{
cfg: cfg,
client: client,
tokenizer: tokenizer,
cfg: cfg,
client: client,
}, nil
}
@ -242,33 +191,9 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
return
}
prompt := gjson.GetBytes(body, "prompt").String()
suffix := gjson.GetBytes(body, "suffix").String()
inputTokens := len(s.tokenizer.Encode(prompt, nil, nil))
suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil))
outputTokens := int(gjson.GetBytes(body, "max_tokens").Int())
totalTokens := inputTokens + suffixTokens + outputTokens
if totalTokens > s.cfg.CodexMaxTokens { // reduce
left, right := 0, len(prompt)
for left < right {
mid := (left + right) / 2
subPrompt := prompt[mid:]
subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil))
totalTokens = subInputTokens + suffixTokens + outputTokens
if totalTokens > s.cfg.CodexMaxTokens {
left = mid + 1
} else {
right = mid
}
}
body, _ = sjson.SetBytes(body, "prompt", prompt[left:])
}
body, _ = sjson.DeleteBytes(body, "extra")
body, _ = sjson.DeleteBytes(body, "nwo")
body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL)
body, _ = sjson.SetBytes(body, "model", "gpt-3.5-turbo-instruct")
proxyUrl := s.cfg.CodexApiBase + "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))