cloudflare 支持
This commit is contained in:
parent
c6f0d41f33
commit
900056622b
85
main.go
85
main.go
|
|
@ -6,7 +6,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/linux-do/tiktoken-go"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/net/http2"
|
||||
|
|
@ -15,14 +14,9 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct"
|
||||
|
||||
type config struct {
|
||||
Bind string `json:"bind"`
|
||||
ProxyUrl string `json:"proxy_url"`
|
||||
|
|
@ -31,7 +25,6 @@ type config struct {
|
|||
CodexApiKey string `json:"codex_api_key"`
|
||||
CodexApiOrganization string `json:"codex_api_organization"`
|
||||
CodexApiProject string `json:"codex_api_project"`
|
||||
CodexMaxTokens int `json:"codex_max_tokens"`
|
||||
ChatApiBase string `json:"chat_api_base"`
|
||||
ChatApiKey string `json:"chat_api_key"`
|
||||
ChatApiOrganization string `json:"chat_api_organization"`
|
||||
|
|
@ -52,43 +45,6 @@ func readConfig() *config {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(_cfg).Elem()
|
||||
t := v.Type()
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
tag := t.Field(i).Tag.Get("json")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
value, exists := os.LookupEnv("OVERRIDE_" + strings.ToUpper(tag))
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString(value)
|
||||
case reflect.Bool:
|
||||
if boolValue, err := strconv.ParseBool(value); err == nil {
|
||||
field.SetBool(boolValue)
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if intValue, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
field.SetInt(intValue)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
if uintValue, err := strconv.ParseUint(value, 10, 64); err == nil {
|
||||
field.SetUint(uintValue)
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if floatValue, err := strconv.ParseFloat(value, field.Type().Bits()); err == nil {
|
||||
field.SetFloat(floatValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return _cfg
|
||||
}
|
||||
|
||||
|
|
@ -135,9 +91,8 @@ func closeIO(c io.Closer) {
|
|||
}
|
||||
|
||||
type ProxyService struct {
|
||||
cfg *config
|
||||
client *http.Client
|
||||
tokenizer *tiktoken.Tiktoken
|
||||
cfg *config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewProxyService(cfg *config) (*ProxyService, error) {
|
||||
|
|
@ -146,15 +101,9 @@ func NewProxyService(cfg *config) (*ProxyService, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL)
|
||||
if nil != err {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ProxyService{
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
tokenizer: tokenizer,
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -242,33 +191,9 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
prompt := gjson.GetBytes(body, "prompt").String()
|
||||
suffix := gjson.GetBytes(body, "suffix").String()
|
||||
inputTokens := len(s.tokenizer.Encode(prompt, nil, nil))
|
||||
suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil))
|
||||
outputTokens := int(gjson.GetBytes(body, "max_tokens").Int())
|
||||
|
||||
totalTokens := inputTokens + suffixTokens + outputTokens
|
||||
if totalTokens > s.cfg.CodexMaxTokens { // reduce
|
||||
left, right := 0, len(prompt)
|
||||
for left < right {
|
||||
mid := (left + right) / 2
|
||||
subPrompt := prompt[mid:]
|
||||
subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil))
|
||||
totalTokens = subInputTokens + suffixTokens + outputTokens
|
||||
if totalTokens > s.cfg.CodexMaxTokens {
|
||||
left = mid + 1
|
||||
} else {
|
||||
right = mid
|
||||
}
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, "prompt", prompt[left:])
|
||||
}
|
||||
|
||||
body, _ = sjson.DeleteBytes(body, "extra")
|
||||
body, _ = sjson.DeleteBytes(body, "nwo")
|
||||
body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL)
|
||||
body, _ = sjson.SetBytes(body, "model", "gpt-3.5-turbo-instruct")
|
||||
|
||||
proxyUrl := s.cfg.CodexApiBase + "/completions"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
|
||||
|
|
|
|||
Loading…
Reference in New Issue