265 lines
6.0 KiB
Go
265 lines
6.0 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"golang.org/x/net/http2"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"time"
|
|
)
|
|
|
|
// config mirrors the JSON document stored in config.json.
type config struct {
	// Bind is the listen address handed to the gin engine's Run.
	Bind string `json:"bind"`
	// ProxyUrl, when non-empty, is parsed as a URL and used as the outbound
	// HTTP proxy for every upstream request.
	ProxyUrl string `json:"proxy_url"`
	// Timeout is the outbound HTTP client timeout, expressed in seconds.
	Timeout int `json:"timeout"`

	// Codex* settings apply to the code-completion endpoint: the base URL
	// that "/completions" is appended to, the bearer token, and the optional
	// OpenAI-Organization / OpenAI-Project header values.
	CodexApiBase         string `json:"codex_api_base"`
	CodexApiKey          string `json:"codex_api_key"`
	CodexApiOrganization string `json:"codex_api_organization"`
	CodexApiProject      string `json:"codex_api_project"`

	// Chat* settings apply to the chat endpoint: the base URL that
	// "/chat/completions" is appended to, the bearer token, and the optional
	// OpenAI-Organization / OpenAI-Project header values.
	ChatApiBase         string `json:"chat_api_base"`
	ChatApiKey          string `json:"chat_api_key"`
	ChatApiOrganization string `json:"chat_api_organization"`
	ChatApiProject      string `json:"chat_api_project"`

	// ChatModelDefault is the model used when the requested model has no
	// entry in ChatModelMap.
	ChatModelDefault string `json:"chat_model_default"`
	// ChatModelMap translates a client-requested model name into the model
	// name sent upstream.
	ChatModelMap map[string]string `json:"chat_model_map"`
}
|
|
|
|
func readConfig() *config {
|
|
content, err := os.ReadFile("config.json")
|
|
if nil != err {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
_cfg := &config{}
|
|
err = json.Unmarshal(content, &_cfg)
|
|
if nil != err {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return _cfg
|
|
}
|
|
|
|
func getClient(cfg *config) (*http.Client, error) {
|
|
transport := &http.Transport{
|
|
ForceAttemptHTTP2: true,
|
|
DisableKeepAlives: false,
|
|
}
|
|
|
|
err := http2.ConfigureTransport(transport)
|
|
if nil != err {
|
|
return nil, err
|
|
}
|
|
|
|
if "" != cfg.ProxyUrl {
|
|
proxyUrl, err := url.Parse(cfg.ProxyUrl)
|
|
if nil != err {
|
|
return nil, err
|
|
}
|
|
|
|
transport.Proxy = http.ProxyURL(proxyUrl)
|
|
}
|
|
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
Timeout: time.Duration(cfg.Timeout) * time.Second,
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func abortCodex(c *gin.Context, status int) {
|
|
c.Header("Content-Type", "text/event-stream")
|
|
|
|
c.String(status, "data: [DONE]\n")
|
|
c.Abort()
|
|
}
|
|
|
|
// closeIO closes c and logs the error, if any, instead of returning it. It is
// meant for deferred cleanup where a close failure cannot be acted upon.
func closeIO(c io.Closer) {
	if closeErr := c.Close(); closeErr != nil {
		log.Println(closeErr)
	}
}
|
|
|
|
type ProxyService struct {
|
|
cfg *config
|
|
client *http.Client
|
|
}
|
|
|
|
func NewProxyService(cfg *config) (*ProxyService, error) {
|
|
client, err := getClient(cfg)
|
|
if nil != err {
|
|
return nil, err
|
|
}
|
|
|
|
return &ProxyService{
|
|
cfg: cfg,
|
|
client: client,
|
|
}, nil
|
|
}
|
|
|
|
func (s *ProxyService) InitRoutes(e *gin.Engine) {
|
|
e.POST("/v1/chat/completions", s.completions)
|
|
e.POST("/v1/engines/copilot-codex/completions", s.codeCompletions)
|
|
}
|
|
|
|
func (s *ProxyService) completions(c *gin.Context) {
|
|
ctx := c.Request.Context()
|
|
|
|
body, err := io.ReadAll(c.Request.Body)
|
|
if nil != err {
|
|
c.AbortWithStatus(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
model := gjson.GetBytes(body, "model").String()
|
|
if mapped, ok := s.cfg.ChatModelMap[model]; ok {
|
|
model = mapped
|
|
} else {
|
|
model = s.cfg.ChatModelDefault
|
|
}
|
|
body, _ = sjson.SetBytes(body, "model", model)
|
|
body, _ = sjson.DeleteBytes(body, "intent")
|
|
|
|
proxyUrl := s.cfg.ChatApiBase + "/chat/completions"
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
|
|
if nil != err {
|
|
c.AbortWithStatus(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+s.cfg.ChatApiKey)
|
|
if "" != s.cfg.ChatApiOrganization {
|
|
req.Header.Set("OpenAI-Organization", s.cfg.ChatApiOrganization)
|
|
}
|
|
if "" != s.cfg.ChatApiProject {
|
|
req.Header.Set("OpenAI-Project", s.cfg.ChatApiProject)
|
|
}
|
|
|
|
resp, err := s.client.Do(req)
|
|
if nil != err {
|
|
if errors.Is(err, context.Canceled) {
|
|
c.AbortWithStatus(http.StatusRequestTimeout)
|
|
return
|
|
}
|
|
|
|
log.Println("request conversation failed:", err.Error())
|
|
c.AbortWithStatus(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer closeIO(resp.Body)
|
|
|
|
if resp.StatusCode != http.StatusOK { // log
|
|
body, _ := io.ReadAll(resp.Body)
|
|
log.Println("request completions failed:", string(body))
|
|
|
|
resp.Body = io.NopCloser(bytes.NewBuffer(body))
|
|
}
|
|
|
|
c.Status(resp.StatusCode)
|
|
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if "" != contentType {
|
|
c.Header("Content-Type", contentType)
|
|
}
|
|
|
|
_, _ = io.Copy(c.Writer, resp.Body)
|
|
}
|
|
|
|
func (s *ProxyService) codeCompletions(c *gin.Context) {
|
|
ctx := c.Request.Context()
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
if ctx.Err() != nil {
|
|
abortCodex(c, http.StatusRequestTimeout)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(c.Request.Body)
|
|
if nil != err {
|
|
abortCodex(c, http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
body, _ = sjson.DeleteBytes(body, "extra")
|
|
body, _ = sjson.DeleteBytes(body, "nwo")
|
|
body, _ = sjson.SetBytes(body, "model", "gpt-3.5-turbo-instruct")
|
|
|
|
proxyUrl := s.cfg.CodexApiBase + "/completions"
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, proxyUrl, io.NopCloser(bytes.NewBuffer(body)))
|
|
if nil != err {
|
|
abortCodex(c, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+s.cfg.CodexApiKey)
|
|
if "" != s.cfg.CodexApiOrganization {
|
|
req.Header.Set("OpenAI-Organization", s.cfg.CodexApiOrganization)
|
|
}
|
|
if "" != s.cfg.CodexApiProject {
|
|
req.Header.Set("OpenAI-Project", s.cfg.CodexApiProject)
|
|
}
|
|
|
|
resp, err := s.client.Do(req)
|
|
if nil != err {
|
|
if errors.Is(err, context.Canceled) {
|
|
abortCodex(c, http.StatusRequestTimeout)
|
|
return
|
|
}
|
|
|
|
log.Println("request completions failed:", err.Error())
|
|
abortCodex(c, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer closeIO(resp.Body)
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
log.Println("request completions failed:", string(body))
|
|
|
|
abortCodex(c, resp.StatusCode)
|
|
return
|
|
}
|
|
|
|
c.Status(resp.StatusCode)
|
|
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if "" != contentType {
|
|
c.Header("Content-Type", contentType)
|
|
}
|
|
|
|
_, _ = io.Copy(c.Writer, resp.Body)
|
|
}
|
|
|
|
func main() {
|
|
cfg := readConfig()
|
|
|
|
gin.SetMode(gin.ReleaseMode)
|
|
r := gin.Default()
|
|
|
|
proxyService, err := NewProxyService(cfg)
|
|
if nil != err {
|
|
log.Fatal(err)
|
|
return
|
|
}
|
|
|
|
proxyService.InitRoutes(r)
|
|
|
|
err = r.Run(cfg.Bind)
|
|
if nil != err {
|
|
log.Fatal(err)
|
|
return
|
|
}
|
|
}
|