siyuan/kernel/api/network.go
Jeffrey Chen 698ee3d357
♻️ Improve validation of some JSON parameters (#17412)
* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters

* 🎨 Improve validation of some JSON parameters
2026-04-05 17:03:13 +08:00

361 lines
9.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// SiYuan - Refactor your thinking
// Copyright (c) 2020-present, b3log.org
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package api
import (
"encoding/base32"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"math"
"net"
"net/http"
"net/textproto"
"net/url"
"strings"
"syscall"
"time"
"github.com/88250/gulu"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/siyuan-note/logging"
"github.com/siyuan-note/siyuan/kernel/util"
)
// File is the JSON-serialisable description of one uploaded multipart file,
// as reported back by echo.
type File struct {
	// Filename is the file name taken from the multipart file header.
	Filename string
	// Header holds the MIME headers of the file's form part.
	Header textproto.MIMEHeader
	// Size is the file size reported by the multipart file header.
	Size int64
	// Content is the file content encoded with standard base64; it stays
	// empty if the file could not be opened or read.
	Content string
}
// MultipartForm is the JSON-serialisable form of a parsed multipart form used
// by echo: plain field values are copied as-is and uploaded files are replaced
// by File descriptions that carry their content.
type MultipartForm struct {
	// Value maps form field names to their submitted values.
	Value map[string][]string
	// File maps form field names to the files uploaded under them.
	File map[string][]File
}
// echo responds with a JSON description of the incoming request, intended for
// debugging API clients. The response contains gin context details (including
// the base64-encoded raw body), the parsed request (including any multipart
// files with their base64-encoded content), URL components and the URL
// userinfo.
func echo(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	defer c.JSON(http.StatusOK, ret)

	// URL.User may be nil; the *url.Userinfo accessors are nil-safe.
	password, passwordSet := c.Request.URL.User.Password()

	var multipartForm *MultipartForm
	if form, err := c.MultipartForm(); err == nil && nil != form {
		multipartForm = &MultipartForm{
			Value: form.Value,
			File:  make(map[string][]File, len(form.File)),
		}
		for k, handlers := range form.File {
			files := make([]File, len(handlers))
			for i, handler := range handlers {
				files[i] = File{
					Filename: handler.Filename,
					Header:   handler.Header,
					Size:     handler.Size,
				}

				file, err := handler.Open()
				if err != nil {
					logging.LogWarnf("echo open form [%s] file [%s] error: %s", k, handler.Filename, err.Error())
					continue
				}
				content := make([]byte, handler.Size)
				// A single Read may return fewer bytes than requested;
				// ReadFull keeps reading until the buffer is filled.
				_, err = io.ReadFull(file, content)
				// Release the handle right away (not deferred, since we are in
				// a loop). The file is only read, so a Close error is not
				// actionable here.
				_ = file.Close()
				if err != nil {
					logging.LogWarnf("echo read form [%s] file [%s] error: %s", k, handler.Filename, err.Error())
					continue
				}
				files[i].Content = base64.StdEncoding.EncodeToString(content)
			}
			multipartForm.File[k] = files
		}
	}

	// rawData stays nil (JSON null) when the body cannot be read.
	var rawData any
	if data, err := c.GetRawData(); err == nil {
		rawData = base64.StdEncoding.EncodeToString(data)
	} else {
		logging.LogWarnf("echo get raw data error: %s", err.Error())
	}

	// Best effort: populate Request.Form/PostForm for the report below. Errors
	// are deliberately ignored; the echo shows whatever could be parsed.
	_ = c.Request.ParseForm()
	_ = c.Request.ParseMultipartForm(math.MaxInt64)

	ret.Data = map[string]any{
		"Context": map[string]any{
			"Params":       c.Params,
			"HandlerNames": c.HandlerNames(),
			"FullPath":     c.FullPath(),
			"ClientIP":     c.ClientIP(),
			"RemoteIP":     c.RemoteIP(),
			"ContentType":  c.ContentType(),
			"IsWebsocket":  c.IsWebsocket(),
			"RawData":      rawData,
		},
		"Request": map[string]any{
			"Method":           c.Request.Method,
			"URL":              c.Request.URL,
			"Proto":            c.Request.Proto,
			"ProtoMajor":       c.Request.ProtoMajor,
			"ProtoMinor":       c.Request.ProtoMinor,
			"Header":           c.Request.Header,
			"ContentLength":    c.Request.ContentLength,
			"TransferEncoding": c.Request.TransferEncoding,
			"Close":            c.Request.Close,
			"Host":             c.Request.Host,
			"Form":             c.Request.Form,
			"PostForm":         c.Request.PostForm,
			"MultipartForm":    multipartForm,
			"Trailer":          c.Request.Trailer,
			"RemoteAddr":       c.Request.RemoteAddr,
			"TLS":              c.Request.TLS,
			"UserAgent":        c.Request.UserAgent(),
			"Cookies":          c.Request.Cookies(),
			"Referer":          c.Request.Referer(),
		},
		"URL": map[string]any{
			"EscapedPath":     c.Request.URL.EscapedPath(),
			"EscapedFragment": c.Request.URL.EscapedFragment(),
			"String":          c.Request.URL.String(),
			"Redacted":        c.Request.URL.Redacted(),
			"IsAbs":           c.Request.URL.IsAbs(),
			"Query":           c.Request.URL.Query(),
			"RequestURI":      c.Request.URL.RequestURI(),
			"Hostname":        c.Request.URL.Hostname(),
			"Port":            c.Request.URL.Port(),
		},
		"User": map[string]any{
			"Username":    c.Request.URL.User.Username(),
			"Password":    password,
			"PasswordSet": passwordSet,
			"String":      c.Request.URL.User.String(),
		},
	}
}
// forwardProxy forwards an HTTP request on behalf of the API caller and
// returns the upstream response as JSON.
//
// JSON arguments:
//   - url (required): absolute http or https URL to request.
//   - method: HTTP method, upper-cased; defaults to "POST".
//   - timeout: timeout in milliseconds; values below 1 fall back to 7000.
//   - headers: array of objects whose key/value pairs are sent as headers.
//   - contentType: request Content-Type; defaults to "application/json".
//   - payload: request body.
//   - payloadEncoding: "json" (default), "text", "base64"/"base64-std",
//     "base64-url", "base32"/"base32-std", "base32-hex" or "hex". The
//     binary encodings require a string payload, which is decoded before
//     sending; other values send the payload unchanged.
//   - responseEncoding: "text" (default) or one of the binary encodings
//     above; controls how the response body is returned.
//
// A present-but-mistyped optional argument yields Code -1 with
// "invalid [<name>]"; a payload that fails to decode yields Code -2.
// Connections are made through getSafeClient.
func forwardProxy(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	defer c.JSON(http.StatusOK, ret)

	arg, ok := util.JsonArg(c, ret)
	if !ok {
		return
	}

	var destURL string
	if !util.ParseJsonArgs(arg, ret, util.BindJsonArg("url", &destURL, true, true)) {
		return
	}
	u, e := url.ParseRequestURI(destURL)
	if nil != e {
		ret.Code = -1
		ret.Msg = "invalid [url]"
		return
	}
	if u.Scheme != "http" && u.Scheme != "https" {
		ret.Code = -1
		ret.Msg = "only http/https is allowed"
		return
	}

	// All optional arguments are validated before any network activity, so a
	// value of the wrong JSON type is reported instead of panicking on a type
	// assertion.
	invalid := func(name string) {
		ret.Code = -1
		ret.Msg = "invalid [" + name + "]"
	}

	method, ok := forwardProxyStringArg(arg, "method", "POST")
	if !ok {
		invalid("method")
		return
	}
	method = strings.ToUpper(method)

	timeout := 7000 * time.Millisecond
	if timeoutArg := arg["timeout"]; nil != timeoutArg {
		ms, isNumber := timeoutArg.(float64)
		if !isNumber {
			invalid("timeout")
			return
		}
		// Clamp huge values so the Duration multiplication cannot overflow.
		const maxTimeoutMs = float64(math.MaxInt64 / int64(time.Millisecond))
		if ms > maxTimeoutMs {
			ms = maxTimeoutMs
		}
		// Fractions are truncated; anything below 1ms keeps the default.
		if 1 <= ms {
			timeout = time.Duration(ms) * time.Millisecond
		}
	}

	contentType, ok := forwardProxyStringArg(arg, "contentType", "application/json")
	if !ok {
		invalid("contentType")
		return
	}
	payloadEncoding, ok := forwardProxyStringArg(arg, "payloadEncoding", "json")
	if !ok {
		invalid("payloadEncoding")
		return
	}
	responseEncoding, ok := forwardProxyStringArg(arg, "responseEncoding", "text")
	if !ok {
		invalid("responseEncoding")
		return
	}

	// "json", "text" and unrecognised encodings send the payload unchanged;
	// binary encodings decode a string payload into raw bytes first.
	payload := arg["payload"]
	if encodingName, decode := forwardProxyPayloadDecoder(payloadEncoding); nil != decode {
		text, isString := payload.(string)
		if !isString {
			invalid("payload")
			return
		}
		decoded, err := decode(text)
		if err != nil {
			ret.Code = -2
			ret.Msg = "decode " + encodingName + " payload failed: " + err.Error()
			return
		}
		payload = decoded
	}

	client := getSafeClient(timeout)
	request := client.R()
	if headers, ok := arg["headers"].([]any); ok {
		for _, pair := range headers {
			if m, ok := pair.(map[string]any); ok {
				for k, v := range m {
					request.SetHeader(k, fmt.Sprintf("%v", v))
				}
			}
		}
	}
	// Set after the custom headers, so it overrides any Content-Type given there.
	request.SetHeader("Content-Type", contentType)
	request.SetBody(payload)

	started := time.Now()
	resp, err := request.Send(method, destURL)
	if err != nil {
		ret.Code = -1
		ret.Msg = "forward request failed: " + err.Error()
		return
	}
	defer resp.Body.Close()

	bodyData, err := io.ReadAll(resp.Body)
	if err != nil {
		ret.Code = -1
		ret.Msg = "read response body failed: " + err.Error()
		return
	}
	elapsed := time.Since(started)

	var body string
	switch responseEncoding {
	case "base64", "base64-std":
		body = base64.StdEncoding.EncodeToString(bodyData)
	case "base64-url":
		body = base64.URLEncoding.EncodeToString(bodyData)
	case "base32", "base32-std":
		body = base32.StdEncoding.EncodeToString(bodyData)
	case "base32-hex":
		body = base32.HexEncoding.EncodeToString(bodyData)
	case "hex":
		body = hex.EncodeToString(bodyData)
	default:
		// "text" and unrecognised encodings return the body verbatim.
		responseEncoding = "text"
		body = string(bodyData)
	}

	ret.Data = map[string]any{
		"url":          destURL,
		"status":       resp.StatusCode,
		"contentType":  resp.GetHeader("content-type"),
		"body":         body,
		"bodyEncoding": responseEncoding,
		"headers":      resp.Header,
		"elapsed":      elapsed.Milliseconds(),
	}
}

// forwardProxyStringArg reads the optional string argument key from arg.
// A missing or null value yields def with ok true; ok is false when the value
// is present but is not a string.
func forwardProxyStringArg(arg map[string]any, key, def string) (value string, ok bool) {
	raw := arg[key]
	if nil == raw {
		return def, true
	}
	value, ok = raw.(string)
	return value, ok
}

// forwardProxyPayloadDecoder maps a payloadEncoding to its canonical name (used
// in error messages) and decoder. It returns a nil decoder for encodings whose
// payload is sent unchanged ("json", "text" and unrecognised values).
func forwardProxyPayloadDecoder(encoding string) (string, func(string) ([]byte, error)) {
	switch encoding {
	case "base64", "base64-std":
		return "base64-std", base64.StdEncoding.DecodeString
	case "base64-url":
		return "base64-url", base64.URLEncoding.DecodeString
	case "base32", "base32-std":
		return "base32-std", base32.StdEncoding.DecodeString
	case "base32-hex":
		return "base32-hex", base32.HexEncoding.DecodeString
	case "hex":
		return "hex", hex.DecodeString
	default:
		return "", nil
	}
}
// isPrivateIP reports whether ip must not be reached by forwarded requests.
// This covers the unspecified address, loopback, link-local unicast and
// private addresses as classified by the net.IP methods.
func isPrivateIP(ip net.IP) bool {
	switch {
	case ip.IsUnspecified(), ip.IsLoopback():
		return true
	case ip.IsLinkLocalUnicast(), ip.IsPrivate():
		return true
	default:
		return false
	}
}
// getSafeClient creates an HTTP client for forwarding requests on behalf of
// API callers, guarding against SSRF and DNS rebinding.
//
// The address check runs in the dialer's Control hook, which is invoked after
// name resolution and right before each connection is established. It
// therefore applies to the address actually dialled, including redirect
// targets (at most 3 redirects are followed), rather than to a hostname that
// could later resolve to something else.
//
// NOTE(review): if the client routes through a proxy (e.g. one picked up from
// the environment), only the proxy's address is checked here. Confirm req's
// default proxy behaviour.
func getSafeClient(timeout time.Duration) *req.Client {
	dialer := &net.Dialer{
		Timeout: timeout,
		Control: func(network, address string, _ syscall.RawConn) error {
			host, _, err := net.SplitHostPort(address)
			if err != nil {
				return err
			}
			// Drop an IPv6 zone ("fe80::1%eth0"). net.ParseIP rejects zoned
			// literals, which previously let link-local targets slip through.
			ipText, _, _ := strings.Cut(host, "%")
			ip := net.ParseIP(ipText)
			// Fail closed: an address we cannot classify must not bypass the check.
			if ip == nil || isPrivateIP(ip) {
				return fmt.Errorf("ip address [%s] is prohibited", host)
			}
			return nil
		},
	}
	client := req.C()
	client.SetTimeout(timeout)
	client.SetDial(dialer.DialContext)
	client.SetRedirectPolicy(req.MaxRedirectPolicy(3))
	return client
}