498 lines
14 KiB
Go
498 lines
14 KiB
Go
package service
|
|
|
|
import (
|
|
"crypto/md5"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"xboard-go/internal/database"
|
|
"xboard-go/internal/model"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
var (
	// serverTypeAliases maps legacy or alternate server type names onto the
	// canonical name used everywhere else in this package. Keys are already
	// lower-case because NormalizeServerType lower-cases before the lookup.
	serverTypeAliases = map[string]string{
		"hysteria2": "hysteria",
		"v2ray":     "vmess",
	}

	// validServerTypes is the set of canonical server types accepted by
	// IsValidServerType. Aliases are not listed here; they are resolved first.
	validServerTypes = map[string]struct{}{
		"anytls":      {},
		"http":        {},
		"hysteria":    {},
		"mieru":       {},
		"naive":       {},
		"shadowsocks": {},
		"socks":       {},
		"trojan":      {},
		"tuic":        {},
		"vless":       {},
		"vmess":       {},
	}
)
|
|
|
|
// NodeUser is the per-user record handed to a node: the user's id and uuid,
// plus optional per-user speed and device limits. The limits are pointers so
// that a missing value (nil) is omitted from the JSON output entirely rather
// than being sent as 0.
type NodeUser struct {
	ID          int    `json:"id"`
	UUID        string `json:"uuid"`
	SpeedLimit  *int   `json:"speed_limit,omitempty"`
	DeviceLimit *int   `json:"device_limit,omitempty"`
}
|
|
|
|
// NodeBaseConfig carries the push and pull intervals sent to every node,
// regardless of protocol. BuildNodeConfig fills them from the
// "server_push_interval" / "server_pull_interval" settings.
type NodeBaseConfig struct {
	PushInterval int `json:"push_interval"`
	PullInterval int `json:"pull_interval"`
}
|
|
|
|
type NodeServerConfig struct {
|
|
Name string `json:"-"`
|
|
Protocol string `json:"protocol"`
|
|
RawHost string `json:"-"`
|
|
ListenIP string `json:"listen_ip"`
|
|
ServerPort int `json:"server_port"`
|
|
Network any `json:"network"`
|
|
NetworkSettings any `json:"networkSettings"`
|
|
Cipher any `json:"cipher,omitempty"`
|
|
Plugin any `json:"plugin,omitempty"`
|
|
PluginOpts any `json:"plugin_opts,omitempty"`
|
|
ServerKey any `json:"server_key,omitempty"`
|
|
Host any `json:"host,omitempty"`
|
|
ServerName any `json:"server_name,omitempty"`
|
|
Tls any `json:"tls,omitempty"`
|
|
TlsSettings any `json:"tls_settings,omitempty"`
|
|
Flow any `json:"flow,omitempty"`
|
|
Multiplex any `json:"multiplex,omitempty"`
|
|
UpMbps any `json:"up_mbps,omitempty"`
|
|
DownMbps any `json:"down_mbps,omitempty"`
|
|
Version any `json:"version,omitempty"`
|
|
Obfs any `json:"obfs,omitempty"`
|
|
ObfsPassword any `json:"obfs-password,omitempty"`
|
|
CongestionControl any `json:"congestion_control,omitempty"`
|
|
AuthTimeout any `json:"auth_timeout,omitempty"`
|
|
ZeroRTTHandshake any `json:"zero_rtt_handshake,omitempty"`
|
|
Heartbeat any `json:"heartbeat,omitempty"`
|
|
PaddingScheme any `json:"padding_scheme,omitempty"`
|
|
Transport any `json:"transport,omitempty"`
|
|
TrafficPattern any `json:"traffic_pattern,omitempty"`
|
|
Decryption any `json:"decryption,omitempty"`
|
|
Routes []model.ServerRoute `json:"routes,omitempty"`
|
|
CustomOutbounds any `json:"custom_outbounds,omitempty"`
|
|
CustomRoutes any `json:"custom_routes,omitempty"`
|
|
CertConfig any `json:"cert_config,omitempty"`
|
|
BaseConfig NodeBaseConfig `json:"base_config"`
|
|
}
|
|
|
|
func NormalizeServerType(serverType string) string {
|
|
serverType = strings.ToLower(strings.TrimSpace(serverType))
|
|
if serverType == "" {
|
|
return ""
|
|
}
|
|
if alias, ok := serverTypeAliases[serverType]; ok {
|
|
return alias
|
|
}
|
|
return serverType
|
|
}
|
|
|
|
func IsValidServerType(serverType string) bool {
|
|
if serverType == "" {
|
|
return true
|
|
}
|
|
_, ok := validServerTypes[NormalizeServerType(serverType)]
|
|
return ok
|
|
}
|
|
|
|
func FindServer(nodeID, nodeType string) (*model.Server, error) {
|
|
query := database.DB.Model(&model.Server{})
|
|
if normalized := NormalizeServerType(nodeType); normalized != "" {
|
|
query = query.Where("type = ?", normalized)
|
|
}
|
|
|
|
var server model.Server
|
|
if err := query.
|
|
Where("code = ? OR id = ?", nodeID, nodeID).
|
|
Order(gorm.Expr("CASE WHEN code = ? THEN 0 ELSE 1 END", nodeID)).
|
|
First(&server).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
server.Type = NormalizeServerType(server.Type)
|
|
return &server, nil
|
|
}
|
|
|
|
func AvailableUsersForNode(node *model.Server) ([]NodeUser, error) {
|
|
groupIDs := parseIntSlice(node.GroupIDs)
|
|
if len(groupIDs) == 0 {
|
|
return []NodeUser{}, nil
|
|
}
|
|
|
|
var users []NodeUser
|
|
err := database.DB.Model(&model.User{}).
|
|
Select("id", "uuid", "speed_limit", "device_limit").
|
|
Where("group_id IN ?", groupIDs).
|
|
Where("u + d < transfer_enable").
|
|
Where("(expired_at >= ? OR expired_at IS NULL)", time.Now().Unix()).
|
|
Where("banned = ?", 0).
|
|
Scan(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
func AvailableServersForUser(user *model.User) ([]model.Server, error) {
|
|
var servers []model.Server
|
|
if err := database.DB.Where("`show` = ?", 1).Order("sort ASC").Find(&servers).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
filtered := make([]model.Server, 0, len(servers))
|
|
for _, server := range servers {
|
|
groupIDs := parseIntSlice(server.GroupIDs)
|
|
if user.GroupID != nil && len(groupIDs) > 0 && !containsInt(groupIDs, *user.GroupID) {
|
|
continue
|
|
}
|
|
|
|
if server.TransferEnable != nil && *server.TransferEnable > 0 && server.U+server.D >= *server.TransferEnable {
|
|
continue
|
|
}
|
|
|
|
filtered = append(filtered, server)
|
|
}
|
|
|
|
return filtered, nil
|
|
}
|
|
|
|
func CurrentRate(server *model.Server) float64 {
|
|
if !server.RateTimeEnable {
|
|
return float64(server.Rate)
|
|
}
|
|
|
|
ranges := parseObjectSlice(server.RateTimeRanges)
|
|
now := time.Now().Format("15:04")
|
|
for _, item := range ranges {
|
|
start, _ := item["start"].(string)
|
|
end, _ := item["end"].(string)
|
|
if start != "" && end != "" && now >= start && now <= end {
|
|
if rate, ok := toFloat64(item["rate"]); ok {
|
|
return rate
|
|
}
|
|
}
|
|
}
|
|
|
|
return float64(server.Rate)
|
|
}
|
|
|
|
func BuildNodeConfig(node *model.Server) NodeServerConfig {
|
|
settings := parseObject(node.ProtocolSettings)
|
|
response := NodeServerConfig{
|
|
Name: node.Name,
|
|
Protocol: node.Type,
|
|
RawHost: node.Host,
|
|
ListenIP: "0.0.0.0",
|
|
ServerPort: node.ServerPort,
|
|
Network: getMapAny(settings, "network"),
|
|
NetworkSettings: getMapAny(settings, "network_settings"),
|
|
BaseConfig: NodeBaseConfig{
|
|
PushInterval: MustGetInt("server_push_interval", 60),
|
|
PullInterval: MustGetInt("server_pull_interval", 60),
|
|
},
|
|
}
|
|
|
|
switch node.Type {
|
|
case "shadowsocks":
|
|
response.Cipher = getMapString(settings, "cipher")
|
|
response.Plugin = getMapAny(settings, "plugin")
|
|
response.PluginOpts = getMapAny(settings, "plugin_opts")
|
|
cipher := getMapString(settings, "cipher")
|
|
response.ServerKey = ""
|
|
switch cipher {
|
|
case "2022-blake3-aes-128-gcm":
|
|
response.ServerKey = serverKey(node.CreatedAt, 16)
|
|
case "2022-blake3-aes-256-gcm", "2022-blake3-chacha20-poly1305":
|
|
response.ServerKey = serverKey(node.CreatedAt, 32)
|
|
}
|
|
case "vmess":
|
|
response.Tls = getMapInt(settings, "tls")
|
|
response.Multiplex = getMapAny(settings, "multiplex")
|
|
case "trojan":
|
|
response.Host = node.Host
|
|
response.ServerName = getMapString(settings, "server_name")
|
|
response.Multiplex = getMapAny(settings, "multiplex")
|
|
response.Tls = getMapInt(settings, "tls")
|
|
if getMapInt(settings, "tls") == 2 {
|
|
response.TlsSettings = getMapAny(settings, "reality_settings")
|
|
} else {
|
|
response.TlsSettings = getMapAny(settings, "tls_settings")
|
|
}
|
|
case "vless":
|
|
response.Tls = getMapInt(settings, "tls")
|
|
response.Flow = getMapString(settings, "flow")
|
|
response.Multiplex = getMapAny(settings, "multiplex")
|
|
response.Decryption = nil
|
|
if encryption, ok := settings["encryption"].(map[string]any); ok {
|
|
if enabled, ok := encryption["enabled"].(bool); ok && enabled {
|
|
response.Decryption = stringify(encryption["decryption"])
|
|
}
|
|
}
|
|
if getMapInt(settings, "tls") == 2 {
|
|
response.TlsSettings = getMapAny(settings, "reality_settings")
|
|
} else {
|
|
response.TlsSettings = getMapAny(settings, "tls_settings")
|
|
}
|
|
case "hysteria":
|
|
tls, _ := settings["tls"].(map[string]any)
|
|
obfs, _ := settings["obfs"].(map[string]any)
|
|
bandwidth, _ := settings["bandwidth"].(map[string]any)
|
|
version := getMapInt(settings, "version")
|
|
response.Version = version
|
|
response.Host = node.Host
|
|
response.ServerName = stringify(tls["server_name"])
|
|
response.UpMbps = mapAnyInt(bandwidth, "up")
|
|
response.DownMbps = mapAnyInt(bandwidth, "down")
|
|
response.Obfs = nil
|
|
response.ObfsPassword = nil
|
|
if version == 1 {
|
|
response.Obfs = stringify(obfs["password"])
|
|
} else if version == 2 {
|
|
if open, ok := obfs["open"].(bool); ok && open {
|
|
response.Obfs = stringify(obfs["type"])
|
|
response.ObfsPassword = stringify(obfs["password"])
|
|
}
|
|
}
|
|
case "tuic":
|
|
tls, _ := settings["tls"].(map[string]any)
|
|
response.Version = getMapInt(settings, "version")
|
|
response.ServerName = stringify(tls["server_name"])
|
|
response.CongestionControl = getMapString(settings, "congestion_control")
|
|
response.TlsSettings = getMapAny(settings, "tls_settings")
|
|
response.AuthTimeout = "3s"
|
|
response.ZeroRTTHandshake = false
|
|
response.Heartbeat = "3s"
|
|
case "anytls":
|
|
tls, _ := settings["tls"].(map[string]any)
|
|
response.ServerName = stringify(tls["server_name"])
|
|
response.PaddingScheme = getMapAny(settings, "padding_scheme")
|
|
case "naive", "http":
|
|
response.Tls = getMapInt(settings, "tls")
|
|
response.TlsSettings = getMapAny(settings, "tls_settings")
|
|
case "mieru":
|
|
response.Transport = getMapString(settings, "transport")
|
|
response.TrafficPattern = getMapString(settings, "traffic_pattern")
|
|
}
|
|
|
|
response.Routes = nil
|
|
if routeIDs := parseIntSlice(node.RouteIDs); len(routeIDs) > 0 {
|
|
var routes []model.ServerRoute
|
|
if err := database.DB.Select("id", "`match`", "action", "action_value").Where("id IN ?", routeIDs).Find(&routes).Error; err == nil {
|
|
response.Routes = routes
|
|
}
|
|
}
|
|
response.CustomOutbounds = parseGenericJSON(node.CustomOutbounds)
|
|
response.CustomRoutes = parseGenericJSON(node.CustomRoutes)
|
|
response.CertConfig = parseGenericJSON(node.CertConfig)
|
|
|
|
return response
|
|
}
|
|
|
|
func ApplyTrafficDelta(userID int, node *model.Server, upload, download int64) {
|
|
groupIDs := parseIntSlice(node.GroupIDs)
|
|
if len(groupIDs) > 0 {
|
|
var count int64
|
|
if err := database.DB.Model(&model.User{}).
|
|
Where("id = ? AND group_id IN ?", userID, groupIDs).
|
|
Count(&count).Error; err != nil || count == 0 {
|
|
return
|
|
}
|
|
}
|
|
|
|
rate := CurrentRate(node)
|
|
scaledUpload := int64(math.Round(float64(upload) * rate))
|
|
scaledDownload := int64(math.Round(float64(download) * rate))
|
|
|
|
database.DB.Model(&model.User{}).
|
|
Where("id = ?", userID).
|
|
Updates(map[string]any{
|
|
"u": gorm.Expr("u + ?", scaledUpload),
|
|
"d": gorm.Expr("d + ?", scaledDownload),
|
|
"t": time.Now().Unix(),
|
|
})
|
|
|
|
database.DB.Model(&model.Server{}).
|
|
Where("id = ?", node.ID).
|
|
Updates(map[string]any{
|
|
"u": gorm.Expr("u + ?", scaledUpload),
|
|
"d": gorm.Expr("d + ?", scaledDownload),
|
|
})
|
|
}
|
|
|
|
// serverKey derives a deterministic key string from a server's creation
// time, in four steps:
//  1. format the Unix timestamp in decimal;
//  2. take its MD5 digest as lower-case hex (32 characters);
//  3. keep the first size characters (capped at 32);
//  4. base64-encode them with the standard alphabet.
//
// It returns "" when createdAt is nil. A negative size panics on the slice
// expression.
func serverKey(createdAt *time.Time, size int) string {
	if createdAt == nil {
		return ""
	}

	seed := strconv.FormatInt(createdAt.Unix(), 10)
	digest := fmt.Sprintf("%x", md5.Sum([]byte(seed)))

	n := size
	if n > len(digest) {
		n = len(digest)
	}
	return base64.StdEncoding.EncodeToString([]byte(digest[:n]))
}
|
|
|
|
func parseIntSlice(raw *string) []int {
|
|
if raw == nil || strings.TrimSpace(*raw) == "" {
|
|
return nil
|
|
}
|
|
|
|
var decoded []any
|
|
if err := json.Unmarshal([]byte(*raw), &decoded); err == nil {
|
|
result := make([]int, 0, len(decoded))
|
|
for _, item := range decoded {
|
|
if value, ok := toInt(item); ok {
|
|
result = append(result, value)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
parts := strings.Split(*raw, ",")
|
|
result := make([]int, 0, len(parts))
|
|
for _, part := range parts {
|
|
if value, err := strconv.Atoi(strings.TrimSpace(part)); err == nil {
|
|
result = append(result, value)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// parseObject decodes a nullable text column holding a JSON object.
//
// It never returns nil. nil, blank or invalid input, and JSON that decodes to
// no map at all (a literal null), all yield an empty map. Callers can
// therefore index the result, or write to it, without a nil check.
func parseObject(raw *string) map[string]any {
	if raw == nil || strings.TrimSpace(*raw) == "" {
		return map[string]any{}
	}

	var decoded map[string]any
	// A JSON "null" decodes without error into a nil map; treat it like the
	// other empty cases so the never-nil contract holds.
	if err := json.Unmarshal([]byte(*raw), &decoded); err != nil || decoded == nil {
		return map[string]any{}
	}
	return decoded
}
|
|
|
|
// parseObjectSlice decodes a nullable text column holding a JSON array of
// objects. It returns nil for nil, blank or undecodable input, and for a JSON
// null.
func parseObjectSlice(raw *string) []map[string]any {
	if raw == nil {
		return nil
	}
	if strings.TrimSpace(*raw) == "" {
		return nil
	}

	var items []map[string]any
	if json.Unmarshal([]byte(*raw), &items) != nil {
		return nil
	}
	return items
}
|
|
|
|
// parseGenericJSON decodes a nullable text column holding arbitrary JSON
// into the generic encoding/json representation (map[string]any, []any,
// float64, string, bool or nil). Blank or invalid input yields nil.
func parseGenericJSON(raw *string) any {
	switch {
	case raw == nil, strings.TrimSpace(*raw) == "":
		return nil
	}

	var out any
	if json.Unmarshal([]byte(*raw), &out) != nil {
		return nil
	}
	return out
}
|
|
|
|
func getMapString(values map[string]any, key string) string {
|
|
return stringify(values[key])
|
|
}
|
|
|
|
func getMapInt(values map[string]any, key string) int {
|
|
if value, ok := toInt(values[key]); ok {
|
|
return value
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// getMapAny returns the raw value stored under key, or nil when the key is
// absent or values itself is nil (indexing a nil map is safe in Go).
func getMapAny(values map[string]any, key string) any {
	return values[key]
}
|
|
|
|
func mapAnyInt(values map[string]any, key string) int {
|
|
if value, ok := toInt(values[key]); ok {
|
|
return value
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// stringify converts a decoded JSON scalar, or a few Go numeric types, to a
// string:
//   - strings are returned as-is;
//   - fmt.Stringer values go through String();
//   - int and int64 are printed in decimal;
//   - float64 values with no fractional part (and magnitude below 1e18,
//     where the int64 conversion is exact) are printed as integers, so the
//     JSON number 8388 becomes "8388";
//   - other float64 values use the shortest decimal form instead of being
//     truncated ("1.5" rather than "1");
//   - anything else, including nil and bool, yields "".
func stringify(value any) string {
	switch typed := value.(type) {
	case string:
		return typed
	case fmt.Stringer:
		return typed.String()
	case float64:
		if typed == math.Trunc(typed) && math.Abs(typed) < 1e18 {
			return strconv.FormatInt(int64(typed), 10)
		}
		return strconv.FormatFloat(typed, 'f', -1, 64)
	case int:
		return strconv.Itoa(typed)
	case int64:
		return strconv.FormatInt(typed, 10)
	default:
		return ""
	}
}
|
|
|
|
// pruneNilMap returns a new map holding the entries of values except those
// whose value is a nil interface or an empty string. Other zero values (0,
// false, empty slices) are kept. The input map is not modified.
func pruneNilMap(values map[string]any) map[string]any {
	kept := make(map[string]any, len(values))
	for k, v := range values {
		switch s := v.(type) {
		case nil:
			continue
		case string:
			if s == "" {
				continue
			}
		}
		kept[k] = v
	}
	return kept
}
|
|
|
|
// toInt converts a loosely-typed value to int, reporting whether the
// conversion succeeded.
//   - int and int64 are converted directly.
//   - float64, the type encoding/json uses for numbers, is truncated
//     toward zero.
//   - Strings are trimmed and parsed with strconv.Atoi.
//   - Every other type reports false.
func toInt(value any) (int, bool) {
	switch v := value.(type) {
	case string:
		n, err := strconv.Atoi(strings.TrimSpace(v))
		return n, err == nil
	case float64:
		return int(v), true
	case int64:
		return int(v), true
	case int:
		return v, true
	}
	return 0, false
}
|
|
|
|
// toFloat64 converts a loosely-typed value to float64, reporting whether the
// conversion succeeded. float64, int and int64 convert directly; strings are
// trimmed and parsed with strconv.ParseFloat. Every other type reports false.
func toFloat64(value any) (float64, bool) {
	if s, isString := value.(string); isString {
		f, err := strconv.ParseFloat(strings.TrimSpace(s), 64)
		return f, err == nil
	}

	switch n := value.(type) {
	case float64:
		return n, true
	case int:
		return float64(n), true
	case int64:
		return float64(n), true
	}
	return 0, false
}
|
|
|
|
// containsInt reports whether target occurs in values. A nil or empty slice
// never contains anything.
func containsInt(values []int, target int) bool {
	for i := range values {
		if values[i] == target {
			return true
		}
	}
	return false
}
|