package main import ( "bytes" "context" "encoding/json" "errors" "github.com/gin-gonic/gin" "github.com/linux-do/tiktoken-go" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "golang.org/x/net/http2" "io" "log" "net/http" "net/url" "os" "reflect" "strconv" "strings" "time" ) const InstructModel = "gpt-3.5-turbo-instruct" type config struct { Bind string `json:"bind"` ProxyUrl string `json:"proxy_url"` Timeout int `json:"timeout"` CodexApiBase string `json:"codex_api_base"` CodexApiKey string `json:"codex_api_key"` CodexApiOrganization string `json:"codex_api_organization"` CodexApiProject string `json:"codex_api_project"` CodexMaxTokens int `json:"codex_max_tokens"` ChatApiBase string `json:"chat_api_base"` ChatApiKey string `json:"chat_api_key"` ChatApiOrganization string `json:"chat_api_organization"` ChatApiProject string `json:"chat_api_project"` ChatMaxTokens int `json:"chat_max_tokens"` ChatModelDefault string `json:"chat_model_default"` ChatModelMap map[string]string `json:"chat_model_map"` } func readConfig() *config { content, err := os.ReadFile("config.json") if nil != err { log.Fatal(err) } _cfg := &config{} err = json.Unmarshal(content, &_cfg) if nil != err { log.Fatal(err) } v := reflect.ValueOf(_cfg).Elem() t := v.Type() for i := 0; i < v.NumField(); i++ { field := v.Field(i) tag := t.Field(i).Tag.Get("json") if tag == "" { continue } value, exists := os.LookupEnv("OVERRIDE_" + strings.ToUpper(tag)) if !exists { continue } switch field.Kind() { case reflect.String: field.SetString(value) case reflect.Bool: if boolValue, err := strconv.ParseBool(value); err == nil { field.SetBool(boolValue) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if intValue, err := strconv.ParseInt(value, 10, 64); err == nil { field.SetInt(intValue) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if uintValue, err := strconv.ParseUint(value, 10, 64); err == nil { field.SetUint(uintValue) } case reflect.Float32, reflect.Float64: if floatValue, err := strconv.ParseFloat(value, field.Type().Bits()); err == nil { field.SetFloat(floatValue) } } } return _cfg } func getClient(cfg *config) (*http.Client, error) { transport := &http.Transport{ ForceAttemptHTTP2: true, DisableKeepAlives: false, } err := http2.ConfigureTransport(transport) if nil != err { return nil, err } if "" != cfg.ProxyUrl { proxyUrl, err := url.Parse(cfg.ProxyUrl) if nil != err { return nil, err } transport.Proxy = http.ProxyURL(proxyUrl) } client := &http.Client{ Transport: transport, Timeout: time.Duration(cfg.Timeout) * time.Second, } return client, nil } func abortCodex(c *gin.Context, status int) { c.Header("Content-Type", "text/event-stream") c.String(status, "data: [DONE]\n") c.Abort() } func closeIO(c io.Closer) { err := c.Close() if nil != err { log.Println(err) } } type ProxyService struct { cfg *config client *http.Client tokenizer *tiktoken.Tiktoken } func NewProxyService(cfg *config) (*ProxyService, error) { client, err := getClient(cfg) if nil != err { return nil, err } tokenizer, err := tiktoken.EncodingForModel(InstructModel) if nil != err { return nil, err } return &ProxyService{ cfg: cfg, client: client, tokenizer: tokenizer, }, nil } func (s *ProxyService) InitRoutes(e *gin.Engine) { e.POST("/v1/chat/completions", s.completions) e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions) } func (s *ProxyService) completions(c *gin.Context) { ctx := c.Request.Context() body, err := io.ReadAll(c.Request.Body) if nil != err { c.AbortWithStatus(http.StatusBadRequest) return } model := gjson.GetBytes(body, "model").String() if mapped, ok := s.cfg.ChatModelMap[model]; ok { model = mapped } else { model = s.cfg.ChatModelDefault } body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.DeleteBytes(body, "intent") body, _ = sjson.DeleteBytes(body, "intent_threshold") body, _ = sjson.DeleteBytes(body, "intent_content") if int(gjson.GetBytes(body, "max_tokens").Int()) > s.cfg.ChatMaxTokens { body, _ = sjson.SetBytes(body, "max_tokens", s.cfg.ChatMaxTokens) } proxyUrl := s.cfg.ChatApiBase + "/chat/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) if nil != err { c.AbortWithStatus(http.StatusInternalServerError) return } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+s.cfg.ChatApiKey) if "" != s.cfg.ChatApiOrganization { req.Header.Set("OpenAI-Organization", s.cfg.ChatApiOrganization) } if "" != s.cfg.ChatApiProject { req.Header.Set("OpenAI-Project", s.cfg.ChatApiProject) } resp, err := s.client.Do(req) if nil != err { if errors.Is(err, context.Canceled) { c.AbortWithStatus(http.StatusRequestTimeout) return } log.Println("request conversation failed:", err.Error()) c.AbortWithStatus(http.StatusInternalServerError) return } defer closeIO(resp.Body) if resp.StatusCode != http.StatusOK { // log body, _ := io.ReadAll(resp.Body) log.Println("request completions failed:", string(body)) resp.Body = io.NopCloser(bytes.NewBuffer(body)) } c.Status(resp.StatusCode) contentType := resp.Header.Get("Content-Type") if "" != contentType { c.Header("Content-Type", contentType) } _, _ = io.Copy(c.Writer, resp.Body) } func (s *ProxyService) countToken(token string) int { if "" == token { return 0 } return len(s.tokenizer.Encode(token, nil, nil)) } func (s *ProxyService) codeCompletions(c *gin.Context) { ctx := c.Request.Context() time.Sleep(100 * time.Millisecond) if ctx.Err() != nil { abortCodex(c, http.StatusRequestTimeout) return } body, err := io.ReadAll(c.Request.Body) if nil != err { abortCodex(c, http.StatusBadRequest) return } prompt := gjson.GetBytes(body, "prompt").String() suffix := gjson.GetBytes(body, "suffix").String() inputTokens := s.countToken(prompt) suffixTokens := s.countToken(suffix) outputTokens := int(gjson.GetBytes(body, "max_tokens").Int()) totalTokens := inputTokens + suffixTokens + outputTokens if totalTokens > s.cfg.CodexMaxTokens { // reduce left, right := 0, len(prompt) for left < right { mid := (left + right) / 2 subPrompt := prompt[mid:] subInputTokens := s.countToken(subPrompt) totalTokens = subInputTokens + suffixTokens + outputTokens if totalTokens > s.cfg.CodexMaxTokens { left = mid + 1 } else { right = mid } } body, _ = sjson.SetBytes(body, "prompt", prompt[left:]) } body, _ = sjson.DeleteBytes(body, "extra") body, _ = sjson.DeleteBytes(body, "nwo") body, _ = sjson.SetBytes(body, "model", InstructModel) proxyUrl := s.cfg.CodexApiBase + "/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body))) if nil != err { abortCodex(c, http.StatusInternalServerError) return } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+s.cfg.CodexApiKey) if "" != s.cfg.CodexApiOrganization { req.Header.Set("OpenAI-Organization", s.cfg.CodexApiOrganization) } if "" != s.cfg.CodexApiProject { req.Header.Set("OpenAI-Project", s.cfg.CodexApiProject) } resp, err := s.client.Do(req) if nil != err { if errors.Is(err, context.Canceled) { abortCodex(c, http.StatusRequestTimeout) return } log.Println("request completions failed:", err.Error()) abortCodex(c, http.StatusInternalServerError) return } defer closeIO(resp.Body) if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) log.Println("request completions failed:", string(body)) abortCodex(c, resp.StatusCode) return } c.Status(resp.StatusCode) contentType := resp.Header.Get("Content-Type") if "" != contentType { c.Header("Content-Type", contentType) } _, _ = io.Copy(c.Writer, resp.Body) } func main() { cfg := readConfig() gin.SetMode(gin.ReleaseMode) r := gin.Default() proxyService, err := NewProxyService(cfg) if nil != err { log.Fatal(err) return } proxyService.InitRoutes(r) err = r.Run(cfg.Bind) if nil != err { log.Fatal(err) return } }