fallback model
This commit is contained in:
parent
8351543f62
commit
811aad5bfc
126
main.go
126
main.go
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/linux-do/tiktoken-go"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/net/http2"
|
||||
|
|
@ -14,10 +15,14 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// INSTRUCT_MODEL is the model name used to select the tokenizer encoding and
// written into the "model" field of proxied completion request bodies.
const INSTRUCT_MODEL = "gpt-3.5-turbo-instruct"
|
||||
|
||||
type config struct {
|
||||
Bind string `json:"bind"`
|
||||
ProxyUrl string `json:"proxy_url"`
|
||||
|
|
@ -26,7 +31,7 @@ type config struct {
|
|||
CodexApiKey string `json:"codex_api_key"`
|
||||
CodexApiOrganization string `json:"codex_api_organization"`
|
||||
CodexApiProject string `json:"codex_api_project"`
|
||||
CodexModelDefault string `json:"codex_model_default"`
|
||||
CodexMaxTokens int `json:"codex_max_tokens"`
|
||||
ChatApiBase string `json:"chat_api_base"`
|
||||
ChatApiKey string `json:"chat_api_key"`
|
||||
ChatApiOrganization string `json:"chat_api_organization"`
|
||||
|
|
@ -47,35 +52,46 @@ func readConfig() *config {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(_cfg).Elem()
|
||||
t := v.Type()
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
tag := t.Field(i).Tag.Get("json")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
value, exists := os.LookupEnv("OVERRIDE_" + strings.ToUpper(tag))
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString(value)
|
||||
case reflect.Bool:
|
||||
if boolValue, err := strconv.ParseBool(value); err == nil {
|
||||
field.SetBool(boolValue)
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if intValue, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||
field.SetInt(intValue)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
if uintValue, err := strconv.ParseUint(value, 10, 64); err == nil {
|
||||
field.SetUint(uintValue)
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if floatValue, err := strconv.ParseFloat(value, field.Type().Bits()); err == nil {
|
||||
field.SetFloat(floatValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return _cfg
|
||||
}
|
||||
|
||||
// GPTMessage is a single chat message made of a role and plain-text content.
// Both fields are always emitted when marshalled to JSON.
type GPTMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||||
// StreamResponse is a JSON payload carrying one text value under the
// "response" key.
type StreamResponse struct {
	Response string `json:"response"`
}
|
||||
// Message is a chat message. Every field is tagged omitempty, so a zero-value
// Message marshals to an empty JSON object.
type Message struct {
	Role string `json:"role,omitempty"`
	// Content is untyped so it can hold whatever JSON value is assigned to it.
	Content any     `json:"content,omitempty"`
	Name    *string `json:"name,omitempty"`
}

// ChatCompletionsStreamResponseChoice is one entry of the "choices" array in
// a ChatCompletionsStreamResponse.
type ChatCompletionsStreamResponseChoice struct {
	Index int     `json:"index"`
	Delta Message `json:"delta"`
	// FinishReason is a pointer so that it is left out of the JSON while nil.
	FinishReason *string `json:"finish_reason,omitempty"`
}

// ChatCompletionsStreamResponse is the JSON envelope for a streamed chat
// completion chunk: identifying metadata plus a list of choices.
type ChatCompletionsStreamResponse struct {
	Id      string                                `json:"id"`
	Object  string                                `json:"object"`
	Created int64                                 `json:"created"`
	Model   string                                `json:"model"`
	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
|
||||
|
||||
func getClient(cfg *config) (*http.Client, error) {
|
||||
transport := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
|
|
@ -119,8 +135,9 @@ func closeIO(c io.Closer) {
|
|||
}
|
||||
|
||||
type ProxyService struct {
|
||||
cfg *config
|
||||
client *http.Client
|
||||
cfg *config
|
||||
client *http.Client
|
||||
tokenizer *tiktoken.Tiktoken
|
||||
}
|
||||
|
||||
func NewProxyService(cfg *config) (*ProxyService, error) {
|
||||
|
|
@ -129,9 +146,15 @@ func NewProxyService(cfg *config) (*ProxyService, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
tokenizer, err := tiktoken.EncodingForModel(INSTRUCT_MODEL)
|
||||
if nil != err {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ProxyService{
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
tokenizer: tokenizer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -219,24 +242,35 @@ func (s *ProxyService) codeCompletions(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
prompt := gjson.GetBytes(body, "prompt").String()
|
||||
suffix := gjson.GetBytes(body, "suffix").String()
|
||||
inputTokens := len(s.tokenizer.Encode(prompt, nil, nil))
|
||||
suffixTokens := len(s.tokenizer.Encode(suffix, nil, nil))
|
||||
outputTokens := int(gjson.GetBytes(body, "max_tokens").Int())
|
||||
|
||||
totalTokens := inputTokens + suffixTokens + outputTokens
|
||||
if totalTokens > s.cfg.CodexMaxTokens { // reduce
|
||||
left, right := 0, len(prompt)
|
||||
for left < right {
|
||||
mid := (left + right) / 2
|
||||
subPrompt := prompt[mid:]
|
||||
subInputTokens := len(s.tokenizer.Encode(subPrompt, nil, nil))
|
||||
totalTokens = subInputTokens + suffixTokens + outputTokens
|
||||
if totalTokens > s.cfg.CodexMaxTokens {
|
||||
left = mid + 1
|
||||
} else {
|
||||
right = mid
|
||||
}
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, "prompt", prompt[left:])
|
||||
}
|
||||
|
||||
body, _ = sjson.DeleteBytes(body, "extra")
|
||||
body, _ = sjson.DeleteBytes(body, "nwo")
|
||||
if s.cfg.CodexModelDefault == "" {
|
||||
s.cfg.CodexModelDefault = "gpt-3.5-turbo-instruct"
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", s.cfg.CodexModelDefault)
|
||||
body, _ = sjson.SetBytes(body, "model", INSTRUCT_MODEL)
|
||||
|
||||
proxyUrl := s.cfg.CodexApiBase
|
||||
if strings.HasPrefix(s.cfg.CodexModelDefault, "@") {
|
||||
proxyUrl = s.cfg.CodexApiBase
|
||||
message := gjson.GetBytes(body, "prompt").String()
|
||||
body, _ = sjson.DeleteBytes(body, "prompt")
|
||||
msg := make([]GPTMessage, 0)
|
||||
msg = append(msg, GPTMessage{Role: "system", Content: "You are a helpful assistant"})
|
||||
msg = append(msg, GPTMessage{Role: "user", Content: message})
|
||||
body, _ = sjson.SetBytes(body, "messages", msg)
|
||||
body, _ = sjson.DeleteBytes(body, "n")
|
||||
}
|
||||
proxyUrl := s.cfg.CodexApiBase + "/completions"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
|
||||
if nil != err {
|
||||
abortCodex(c, http.StatusInternalServerError)
|
||||
|
|
|
|||
Loading…
Reference in New Issue