mirror of
https://github.com/siyuan-note/siyuan
synced 2026-04-21 13:37:52 +00:00
* 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters * 🎨 Improve validation of some JSON parameters
361 lines
9.7 KiB
Go
361 lines
9.7 KiB
Go
// SiYuan - Refactor your thinking
|
||
// Copyright (c) 2020-present, b3log.org
|
||
//
|
||
// This program is free software: you can redistribute it and/or modify
|
||
// it under the terms of the GNU Affero General Public License as published by
|
||
// the Free Software Foundation, either version 3 of the License, or
|
||
// (at your option) any later version.
|
||
//
|
||
// This program is distributed in the hope that it will be useful,
|
||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||
// GNU Affero General Public License for more details.
|
||
//
|
||
// You should have received a copy of the GNU Affero General Public License
|
||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||
|
||
package api
|
||
|
||
import (
|
||
"encoding/base32"
|
||
"encoding/base64"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"io"
|
||
"math"
|
||
"net"
|
||
"net/http"
|
||
"net/textproto"
|
||
"net/url"
|
||
"strings"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/88250/gulu"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/imroc/req/v3"
|
||
"github.com/siyuan-note/logging"
|
||
"github.com/siyuan-note/siyuan/kernel/util"
|
||
)
|
||
|
||
type File struct {
|
||
Filename string
|
||
Header textproto.MIMEHeader
|
||
Size int64
|
||
Content string
|
||
}
|
||
|
||
type MultipartForm struct {
|
||
Value map[string][]string
|
||
File map[string][]File
|
||
}
|
||
|
||
// echo reflects the incoming HTTP request back to the caller as JSON, which is
// handy for inspecting exactly what the kernel received.
//
// ret.Data has four sections:
//   - Context: selected gin.Context properties, plus RawData, the request body
//     encoded as standard base64 (nil if reading the body failed).
//   - Request: selected fields of the underlying *http.Request, including the
//     parsed Form/PostForm and the multipart form (file contents base64-encoded).
//   - URL: values derived from the request URL.
//   - User: the userinfo component of the request URL.
//
// NOTE(review): c.MultipartForm() runs before c.GetRawData(), and form parsing
// may consume the body, so RawData can be empty for multipart (and possibly
// urlencoded) requests — confirm whether that is intended.
func echo(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	defer c.JSON(http.StatusOK, ret)

	password, passwordSet := c.Request.URL.User.Password()
	multipartForm := echoMultipartForm(c)

	var rawData any
	if data, err := c.GetRawData(); err == nil {
		rawData = base64.StdEncoding.EncodeToString(data)
	} else {
		logging.LogWarnf("echo get raw data error: %s", err.Error())
	}

	// Best effort: these calls only populate Request.Form/PostForm for the output
	// below. ParseMultipartForm returns an error for non-multipart bodies, which
	// is expected here, so the errors are deliberately ignored.
	_ = c.Request.ParseForm()
	_ = c.Request.ParseMultipartForm(math.MaxInt64)

	ret.Data = map[string]any{
		"Context": map[string]any{
			"Params":       c.Params,
			"HandlerNames": c.HandlerNames(),
			"FullPath":     c.FullPath(),
			"ClientIP":     c.ClientIP(),
			"RemoteIP":     c.RemoteIP(),
			"ContentType":  c.ContentType(),
			"IsWebsocket":  c.IsWebsocket(),
			"RawData":      rawData,
		},
		"Request": map[string]any{
			"Method":           c.Request.Method,
			"URL":              c.Request.URL,
			"Proto":            c.Request.Proto,
			"ProtoMajor":       c.Request.ProtoMajor,
			"ProtoMinor":       c.Request.ProtoMinor,
			"Header":           c.Request.Header,
			"ContentLength":    c.Request.ContentLength,
			"TransferEncoding": c.Request.TransferEncoding,
			"Close":            c.Request.Close,
			"Host":             c.Request.Host,
			"Form":             c.Request.Form,
			"PostForm":         c.Request.PostForm,
			"MultipartForm":    multipartForm,
			"Trailer":          c.Request.Trailer,
			"RemoteAddr":       c.Request.RemoteAddr,
			"TLS":              c.Request.TLS,
			"UserAgent":        c.Request.UserAgent(),
			"Cookies":          c.Request.Cookies(),
			"Referer":          c.Request.Referer(),
		},
		"URL": map[string]any{
			"EscapedPath":     c.Request.URL.EscapedPath(),
			"EscapedFragment": c.Request.URL.EscapedFragment(),
			"String":          c.Request.URL.String(),
			"Redacted":        c.Request.URL.Redacted(),
			"IsAbs":           c.Request.URL.IsAbs(),
			"Query":           c.Request.URL.Query(),
			"RequestURI":      c.Request.URL.RequestURI(),
			"Hostname":        c.Request.URL.Hostname(),
			"Port":            c.Request.URL.Port(),
		},
		"User": map[string]any{
			"Username":    c.Request.URL.User.Username(),
			"Password":    password,
			"PasswordSet": passwordSet,
			"String":      c.Request.URL.User.String(),
		},
	}
}

// echoMultipartForm returns the request's multipart form converted for echo's
// JSON output, or nil when c.MultipartForm reports an error or no form.
// Each uploaded file's content is base64-encoded (standard alphabet). A file
// that cannot be opened or read is logged and keeps an empty Content.
func echoMultipartForm(c *gin.Context) *MultipartForm {
	form, err := c.MultipartForm()
	if err != nil || nil == form {
		return nil
	}

	out := &MultipartForm{
		Value: form.Value,
		File:  make(map[string][]File, len(form.File)),
	}
	for k, handlers := range form.File {
		files := make([]File, len(handlers))
		for i, handler := range handlers {
			files[i] = File{
				Filename: handler.Filename,
				Header:   handler.Header,
				Size:     handler.Size,
			}

			file, err := handler.Open()
			if err != nil {
				logging.LogWarnf("echo open form [%s] file [%s] error: %s", k, handler.Filename, err.Error())
				continue
			}
			// io.ReadAll reads until EOF, whereas a single Read may return fewer
			// bytes than the file holds. The handle is closed immediately rather
			// than deferred, which would keep every file open until return.
			content, err := io.ReadAll(file)
			_ = file.Close() // read-only upload; a close error is not actionable here
			if err != nil {
				logging.LogWarnf("echo read form [%s] file [%s] error: %s", k, handler.Filename, err.Error())
				continue
			}
			files[i].Content = base64.StdEncoding.EncodeToString(content)
		}
		out.File[k] = files
	}
	return out
}
|
||
|
||
// forwardProxy forwards an HTTP request, described by the JSON request body, to
// an external http/https URL and returns the upstream response as JSON.
//
// JSON arguments:
//   - url (string): destination; must parse as an absolute http or https URL.
//   - method (string, optional): HTTP method, upper-cased; defaults to "POST".
//   - timeout (number, optional): timeout in milliseconds; values below 1 fall
//     back to 7000.
//   - headers (array of objects, optional): every key/value pair of every object
//     is sent as a request header.
//   - contentType (string, optional): Content-Type header; defaults to
//     "application/json".
//   - payload (optional): request body, interpreted according to payloadEncoding.
//   - payloadEncoding (string, optional): "json" (default), "text" and unknown
//     values send payload as-is; "base64"/"base64-std", "base64-url",
//     "base32"/"base32-std", "base32-hex" and "hex" decode a string payload into
//     raw bytes first.
//   - responseEncoding (string, optional): encoding of the returned body, using
//     the same names; "text" (default, also used for unknown values) returns the
//     body as a string.
//
// On success ret.Data holds url, status, contentType, body, bodyEncoding, headers
// and elapsed (milliseconds). ret.Code is -1 for invalid arguments and transport
// errors, and -2 when the payload cannot be decoded.
func forwardProxy(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	defer c.JSON(http.StatusOK, ret)

	arg, ok := util.JsonArg(c, ret)
	if !ok {
		return
	}

	var destURL string
	if !util.ParseJsonArgs(arg, ret, util.BindJsonArg("url", &destURL, true, true)) {
		return
	}
	u, e := url.ParseRequestURI(destURL)
	if nil != e {
		ret.Code = -1
		ret.Msg = "invalid [url]"
		return
	}
	if u.Scheme != "http" && u.Scheme != "https" {
		ret.Code = -1
		ret.Msg = "only http/https is allowed"
		return
	}

	// Validate every optional argument before touching the network. A value of
	// the wrong JSON type is reported as "invalid [name]" instead of panicking
	// on an unchecked type assertion.
	method := "POST"
	contentType := "application/json"
	payloadEncoding := "json"
	responseEncoding := "text"
	for _, opt := range []struct {
		name string
		dst  *string
	}{
		{"method", &method},
		{"contentType", &contentType},
		{"payloadEncoding", &payloadEncoding},
		{"responseEncoding", &responseEncoding},
	} {
		v := arg[opt.name]
		if nil == v {
			continue
		}
		s, isString := v.(string)
		if !isString {
			ret.Code = -1
			ret.Msg = "invalid [" + opt.name + "]"
			return
		}
		*opt.dst = s
	}
	method = strings.ToUpper(method)

	timeout := 7000
	if timeoutArg := arg["timeout"]; nil != timeoutArg {
		ms, isNumber := timeoutArg.(float64)
		if !isNumber {
			ret.Code = -1
			ret.Msg = "invalid [timeout]"
			return
		}
		timeout = int(ms)
		if 1 > timeout {
			timeout = 7000
		}
	}

	reqBody, err := decodeForwardProxyPayload(payloadEncoding, arg["payload"])
	if err != nil {
		ret.Code = -2
		ret.Msg = err.Error()
		return
	}

	client := getSafeClient(time.Duration(timeout) * time.Millisecond)
	request := client.R()
	if headers, ok := arg["headers"].([]any); ok {
		for _, pair := range headers {
			if m, ok := pair.(map[string]any); ok {
				for k, v := range m {
					request.SetHeader(k, fmt.Sprintf("%v", v))
				}
			}
		}
	}
	// Applied after the custom headers, presumably so contentType overrides any
	// Content-Type supplied among them.
	request.SetHeader("Content-Type", contentType)
	request.SetBody(reqBody)

	started := time.Now()
	resp, err := request.Send(method, destURL)
	if err != nil {
		ret.Code = -1
		ret.Msg = "forward request failed: " + err.Error()
		return
	}
	defer resp.Body.Close()

	bodyData, err := io.ReadAll(resp.Body)
	if err != nil {
		ret.Code = -1
		ret.Msg = "read response body failed: " + err.Error()
		return
	}
	elapsed := time.Since(started)

	body, bodyEncoding := encodeForwardProxyBody(responseEncoding, bodyData)
	ret.Data = map[string]any{
		"url":          destURL,
		"status":       resp.StatusCode,
		"contentType":  resp.GetHeader("content-type"),
		"body":         body,
		"bodyEncoding": bodyEncoding,
		"headers":      resp.Header,
		"elapsed":      elapsed.Milliseconds(),
	}
}

// decodeForwardProxyPayload turns forwardProxy's payload argument into a request
// body according to encoding. Binary encodings require payload to be a string
// and return the decoded bytes. "json", "text" and unrecognized encodings return
// payload unchanged. Errors read "decode <encoding> payload failed: ...".
func decodeForwardProxyPayload(encoding string, payload any) (any, error) {
	var (
		label  string
		decode func(string) ([]byte, error)
	)
	switch encoding {
	case "base64", "base64-std":
		label, decode = "base64-std", base64.StdEncoding.DecodeString
	case "base64-url":
		label, decode = "base64-url", base64.URLEncoding.DecodeString
	case "base32", "base32-std":
		label, decode = "base32-std", base32.StdEncoding.DecodeString
	case "base32-hex":
		label, decode = "base32-hex", base32.HexEncoding.DecodeString
	case "hex":
		label, decode = "hex", hex.DecodeString
	default:
		return payload, nil
	}

	text, ok := payload.(string)
	if !ok {
		return nil, fmt.Errorf("decode %s payload failed: payload must be a string", label)
	}
	data, err := decode(text)
	if err != nil {
		return nil, fmt.Errorf("decode %s payload failed: %w", label, err)
	}
	return data, nil
}

// encodeForwardProxyBody renders an upstream response body for JSON output and
// returns the rendered body together with the encoding name to report.
// Unrecognized encodings fall back to "text", i.e. the raw body as a string.
func encodeForwardProxyBody(encoding string, data []byte) (string, string) {
	switch encoding {
	case "base64", "base64-std":
		return base64.StdEncoding.EncodeToString(data), encoding
	case "base64-url":
		return base64.URLEncoding.EncodeToString(data), encoding
	case "base32", "base32-std":
		return base32.StdEncoding.EncodeToString(data), encoding
	case "base32-hex":
		return base32.HexEncoding.EncodeToString(data), encoding
	case "hex":
		return hex.EncodeToString(data), encoding
	default:
		return string(data), "text"
	}
}
|
||
|
||
// isPrivateIP reports whether ip is a non-public address that the forward proxy
// must not connect to. It matches:
//   - loopback (127.0.0.0/8, ::1) and unspecified (0.0.0.0, ::) addresses;
//   - link-local unicast (169.254.0.0/16, fe80::/10), which includes the cloud
//     metadata endpoint 169.254.169.254;
//   - private ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7);
//   - the RFC 6598 shared address space 100.64.0.0/10, used for internal services
//     by some clouds (e.g. Alibaba Cloud metadata at 100.100.100.200) and by
//     CGNAT/VPN overlays;
//   - multicast (224.0.0.0/4, ff00::/8).
//
// IPv4-mapped IPv6 addresses are classified by their IPv4 part.
func isPrivateIP(ip net.IP) bool {
	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() {
		return true
	}
	// 100.64.0.0/10: first octet 100 and the top two bits of the second octet 01.
	ip4 := ip.To4()
	return ip4 != nil && ip4[0] == 100 && ip4[1]&0xc0 == 0x40
}
|
||
|
||
// getSafeClient returns an HTTP client that refuses to connect to non-public
// addresses (as classified by isPrivateIP), guarding against SSRF and DNS
// rebinding.
//
// The check lives in the dialer's Control hook. Control runs after the host name
// has been resolved and right before the socket connects, so it sees the IP
// actually being dialed rather than the host name in the URL. This presumably
// also covers connections made while following redirects, since they use the
// same dialer.
//
// timeout is used both as the dial timeout and as the client's overall timeout.
// At most 3 redirects are followed.
//
// NOTE(review): if the client routes through a proxy (e.g. one taken from the
// environment), Control sees the proxy's address rather than the destination's —
// confirm how req.C() configures proxies.
func getSafeClient(timeout time.Duration) *req.Client {
	dialer := &net.Dialer{
		Timeout: timeout,
		Control: func(network, address string, c syscall.RawConn) error {
			host, _, err := net.SplitHostPort(address)
			if err != nil {
				return err
			}
			// A zoned IPv6 literal such as "fe80::1%eth0" is rejected by
			// net.ParseIP. Strip the zone so such addresses are still checked
			// instead of slipping through as unparsable.
			ipText := host
			if i := strings.LastIndexByte(ipText, '%'); i >= 0 {
				ipText = ipText[:i]
			}
			ip := net.ParseIP(ipText)
			if ip == nil {
				// Fail closed: never connect to an address that could not be verified.
				return fmt.Errorf("ip address [%s] is invalid", host)
			}
			if isPrivateIP(ip) {
				return fmt.Errorf("ip address [%s] is prohibited", host)
			}
			return nil
		},
	}

	client := req.C()
	client.SetTimeout(timeout)
	client.SetDial(dialer.DialContext)
	client.SetRedirectPolicy(req.MaxRedirectPolicy(3))
	return client
}
|