438 lines
11 KiB
Go
438 lines
11 KiB
Go
package service
|
|
|
|
import (
|
|
"crypto/md5"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"xboard-go/internal/database"
|
|
"xboard-go/internal/model"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Lookup tables used by NormalizeServerType and IsValidServerType.
var (
	// serverTypeAliases maps an alternative protocol name to the canonical
	// name that appears as a key in validServerTypes.
	serverTypeAliases = map[string]string{
		"v2ray":     "vmess",
		"hysteria2": "hysteria",
	}

	// validServerTypes is the set of canonical protocol names that
	// IsValidServerType accepts.
	validServerTypes = map[string]struct{}{
		"anytls":      {},
		"http":        {},
		"hysteria":    {},
		"mieru":       {},
		"naive":       {},
		"shadowsocks": {},
		"socks":       {},
		"trojan":      {},
		"tuic":        {},
		"vless":       {},
		"vmess":       {},
	}
)
|
|
|
|
// NodeUser is the per-user record returned by AvailableUsersForNode, which
// selects the id, uuid, speed_limit and device_limit columns into it.
type NodeUser struct {
	// ID is serialized as "id".
	ID int `json:"id"`
	// UUID is serialized as "uuid".
	UUID string `json:"uuid"`
	// SpeedLimit is optional; a nil pointer is omitted from the JSON output.
	SpeedLimit *int `json:"speed_limit,omitempty"`
	// DeviceLimit is optional; a nil pointer is omitted from the JSON output.
	DeviceLimit *int `json:"device_limit,omitempty"`
}
|
|
|
|
func NormalizeServerType(serverType string) string {
|
|
serverType = strings.ToLower(strings.TrimSpace(serverType))
|
|
if serverType == "" {
|
|
return ""
|
|
}
|
|
if alias, ok := serverTypeAliases[serverType]; ok {
|
|
return alias
|
|
}
|
|
return serverType
|
|
}
|
|
|
|
func IsValidServerType(serverType string) bool {
|
|
if serverType == "" {
|
|
return true
|
|
}
|
|
_, ok := validServerTypes[NormalizeServerType(serverType)]
|
|
return ok
|
|
}
|
|
|
|
func FindServer(nodeID, nodeType string) (*model.Server, error) {
|
|
query := database.DB.Model(&model.Server{})
|
|
if normalized := NormalizeServerType(nodeType); normalized != "" {
|
|
query = query.Where("type = ?", normalized)
|
|
}
|
|
|
|
var server model.Server
|
|
if err := query.
|
|
Where("code = ? OR id = ?", nodeID, nodeID).
|
|
Order(gorm.Expr("CASE WHEN code = ? THEN 0 ELSE 1 END", nodeID)).
|
|
First(&server).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
server.Type = NormalizeServerType(server.Type)
|
|
return &server, nil
|
|
}
|
|
|
|
func AvailableUsersForNode(node *model.Server) ([]NodeUser, error) {
|
|
groupIDs := parseIntSlice(node.GroupIDs)
|
|
if len(groupIDs) == 0 {
|
|
return []NodeUser{}, nil
|
|
}
|
|
|
|
var users []NodeUser
|
|
err := database.DB.Model(&model.User{}).
|
|
Select("id", "uuid", "speed_limit", "device_limit").
|
|
Where("group_id IN ?", groupIDs).
|
|
Where("u + d < transfer_enable").
|
|
Where("(expired_at >= ? OR expired_at IS NULL)", time.Now().Unix()).
|
|
Where("banned = ?", 0).
|
|
Scan(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
func AvailableServersForUser(user *model.User) ([]model.Server, error) {
|
|
var servers []model.Server
|
|
if err := database.DB.Where("`show` = ?", 1).Order("sort ASC").Find(&servers).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
filtered := make([]model.Server, 0, len(servers))
|
|
for _, server := range servers {
|
|
groupIDs := parseIntSlice(server.GroupIDs)
|
|
if user.GroupID != nil && len(groupIDs) > 0 && !containsInt(groupIDs, *user.GroupID) {
|
|
continue
|
|
}
|
|
|
|
if server.TransferEnable != nil && *server.TransferEnable > 0 && server.U+server.D >= *server.TransferEnable {
|
|
continue
|
|
}
|
|
|
|
filtered = append(filtered, server)
|
|
}
|
|
|
|
return filtered, nil
|
|
}
|
|
|
|
func CurrentRate(server *model.Server) float64 {
|
|
if !server.RateTimeEnable {
|
|
return float64(server.Rate)
|
|
}
|
|
|
|
ranges := parseObjectSlice(server.RateTimeRanges)
|
|
now := time.Now().Format("15:04")
|
|
for _, item := range ranges {
|
|
start, _ := item["start"].(string)
|
|
end, _ := item["end"].(string)
|
|
if start != "" && end != "" && now >= start && now <= end {
|
|
if rate, ok := toFloat64(item["rate"]); ok {
|
|
return rate
|
|
}
|
|
}
|
|
}
|
|
|
|
return float64(server.Rate)
|
|
}
|
|
|
|
func BuildNodeConfig(node *model.Server) map[string]any {
|
|
settings := parseObject(node.ProtocolSettings)
|
|
response := map[string]any{
|
|
"protocol": node.Type,
|
|
"listen_ip": "0.0.0.0",
|
|
"server_port": node.ServerPort,
|
|
"network": getMapString(settings, "network"),
|
|
"networkSettings": getMapAny(settings, "network_settings"),
|
|
}
|
|
|
|
switch node.Type {
|
|
case "shadowsocks":
|
|
response["cipher"] = getMapString(settings, "cipher")
|
|
response["plugin"] = getMapString(settings, "plugin")
|
|
response["plugin_opts"] = getMapString(settings, "plugin_opts")
|
|
cipher := getMapString(settings, "cipher")
|
|
switch cipher {
|
|
case "2022-blake3-aes-128-gcm":
|
|
response["server_key"] = serverKey(node.CreatedAt, 16)
|
|
case "2022-blake3-aes-256-gcm", "2022-blake3-chacha20-poly1305":
|
|
response["server_key"] = serverKey(node.CreatedAt, 32)
|
|
}
|
|
case "vmess":
|
|
response["tls"] = getMapInt(settings, "tls")
|
|
response["multiplex"] = getMapAny(settings, "multiplex")
|
|
case "trojan":
|
|
response["host"] = node.Host
|
|
response["server_name"] = getMapString(settings, "server_name")
|
|
response["multiplex"] = getMapAny(settings, "multiplex")
|
|
response["tls"] = getMapInt(settings, "tls")
|
|
if getMapInt(settings, "tls") == 2 {
|
|
response["tls_settings"] = getMapAny(settings, "reality_settings")
|
|
}
|
|
case "vless":
|
|
response["tls"] = getMapInt(settings, "tls")
|
|
response["flow"] = getMapString(settings, "flow")
|
|
response["multiplex"] = getMapAny(settings, "multiplex")
|
|
if encryption, ok := settings["encryption"].(map[string]any); ok {
|
|
if enabled, ok := encryption["enabled"].(bool); ok && enabled {
|
|
response["decryption"] = stringify(encryption["decryption"])
|
|
}
|
|
}
|
|
if getMapInt(settings, "tls") == 2 {
|
|
response["tls_settings"] = getMapAny(settings, "reality_settings")
|
|
} else {
|
|
response["tls_settings"] = getMapAny(settings, "tls_settings")
|
|
}
|
|
case "hysteria":
|
|
tls, _ := settings["tls"].(map[string]any)
|
|
obfs, _ := settings["obfs"].(map[string]any)
|
|
bandwidth, _ := settings["bandwidth"].(map[string]any)
|
|
version := getMapInt(settings, "version")
|
|
response["version"] = version
|
|
response["host"] = node.Host
|
|
response["server_name"] = stringify(tls["server_name"])
|
|
response["up_mbps"] = mapAnyInt(bandwidth, "up")
|
|
response["down_mbps"] = mapAnyInt(bandwidth, "down")
|
|
if version == 1 {
|
|
response["obfs"] = stringify(obfs["password"])
|
|
} else if version == 2 {
|
|
if open, ok := obfs["open"].(bool); ok && open {
|
|
response["obfs"] = stringify(obfs["type"])
|
|
response["obfs-password"] = stringify(obfs["password"])
|
|
}
|
|
}
|
|
case "tuic":
|
|
tls, _ := settings["tls"].(map[string]any)
|
|
response["version"] = getMapInt(settings, "version")
|
|
response["server_name"] = stringify(tls["server_name"])
|
|
response["congestion_control"] = getMapString(settings, "congestion_control")
|
|
response["tls_settings"] = getMapAny(settings, "tls_settings")
|
|
response["auth_timeout"] = "3s"
|
|
response["zero_rtt_handshake"] = false
|
|
response["heartbeat"] = "3s"
|
|
case "anytls":
|
|
tls, _ := settings["tls"].(map[string]any)
|
|
response["server_name"] = stringify(tls["server_name"])
|
|
response["padding_scheme"] = getMapAny(settings, "padding_scheme")
|
|
case "naive", "http":
|
|
response["tls"] = getMapInt(settings, "tls")
|
|
response["tls_settings"] = getMapAny(settings, "tls_settings")
|
|
case "mieru":
|
|
response["transport"] = getMapString(settings, "transport")
|
|
response["traffic_pattern"] = getMapString(settings, "traffic_pattern")
|
|
}
|
|
|
|
if routeIDs := parseIntSlice(node.RouteIDs); len(routeIDs) > 0 {
|
|
var routes []model.ServerRoute
|
|
if err := database.DB.Select("id", "`match`", "action", "action_value").Where("id IN ?", routeIDs).Find(&routes).Error; err == nil {
|
|
response["routes"] = routes
|
|
}
|
|
}
|
|
if value := parseGenericJSON(node.CustomOutbounds); value != nil {
|
|
response["custom_outbounds"] = value
|
|
}
|
|
if value := parseGenericJSON(node.CustomRoutes); value != nil {
|
|
response["custom_routes"] = value
|
|
}
|
|
if value := parseGenericJSON(node.CertConfig); value != nil {
|
|
response["cert_config"] = value
|
|
}
|
|
|
|
return pruneNilMap(response)
|
|
}
|
|
|
|
func ApplyTrafficDelta(userID int, node *model.Server, upload, download int64) {
|
|
rate := CurrentRate(node)
|
|
scaledUpload := int64(math.Round(float64(upload) * rate))
|
|
scaledDownload := int64(math.Round(float64(download) * rate))
|
|
|
|
database.DB.Model(&model.User{}).
|
|
Where("id = ?", userID).
|
|
Updates(map[string]any{
|
|
"u": gorm.Expr("u + ?", scaledUpload),
|
|
"d": gorm.Expr("d + ?", scaledDownload),
|
|
"t": time.Now().Unix(),
|
|
})
|
|
|
|
database.DB.Model(&model.Server{}).
|
|
Where("id = ?", node.ID).
|
|
Updates(map[string]any{
|
|
"u": gorm.Expr("u + ?", scaledUpload),
|
|
"d": gorm.Expr("d + ?", scaledDownload),
|
|
})
|
|
}
|
|
|
|
// serverKey derives a key string from a creation time.
//
// The Unix timestamp of createdAt is written in decimal, hashed with MD5, and
// the lowercase hex digest is cut to its first size characters (the whole
// 32-character digest when size is larger); the result is the standard
// base64 encoding of those characters. A nil createdAt yields "".
func serverKey(createdAt *time.Time, size int) string {
	if createdAt == nil {
		return ""
	}

	stamp := strconv.FormatInt(createdAt.Unix(), 10)
	digest := fmt.Sprintf("%x", md5.Sum([]byte(stamp)))
	if len(digest) < size {
		size = len(digest)
	}
	prefix := digest[:size]
	return base64.StdEncoding.EncodeToString([]byte(prefix))
}
|
|
|
|
func parseIntSlice(raw *string) []int {
|
|
if raw == nil || strings.TrimSpace(*raw) == "" {
|
|
return nil
|
|
}
|
|
|
|
var decoded []any
|
|
if err := json.Unmarshal([]byte(*raw), &decoded); err == nil {
|
|
result := make([]int, 0, len(decoded))
|
|
for _, item := range decoded {
|
|
if value, ok := toInt(item); ok {
|
|
result = append(result, value)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
parts := strings.Split(*raw, ",")
|
|
result := make([]int, 0, len(parts))
|
|
for _, part := range parts {
|
|
if value, err := strconv.Atoi(strings.TrimSpace(part)); err == nil {
|
|
result = append(result, value)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// parseObject decodes raw as a JSON object and always returns a non-nil map.
//
// An empty map is returned when raw is nil or blank, when it is not valid
// JSON for an object, and when it is the JSON literal null (which decodes
// without error but leaves the map nil). Returning a usable map in every case
// means callers can both read from and write to the result safely.
func parseObject(raw *string) map[string]any {
	if raw == nil || strings.TrimSpace(*raw) == "" {
		return map[string]any{}
	}

	var decoded map[string]any
	if err := json.Unmarshal([]byte(*raw), &decoded); err != nil || decoded == nil {
		return map[string]any{}
	}
	return decoded
}
|
|
|
|
// parseObjectSlice decodes raw as a JSON array of objects. It returns nil
// when raw is nil or blank, or when the text does not decode into that shape.
func parseObjectSlice(raw *string) []map[string]any {
	if raw == nil {
		return nil
	}
	if strings.TrimSpace(*raw) == "" {
		return nil
	}

	var items []map[string]any
	if json.Unmarshal([]byte(*raw), &items) != nil {
		return nil
	}
	return items
}
|
|
|
|
// parseGenericJSON decodes raw into whatever Go value encoding/json produces
// for it (map, slice, string, float64, bool). It returns nil when raw is nil
// or blank, when the text is not valid JSON, and for the JSON literal null.
func parseGenericJSON(raw *string) any {
	if raw == nil {
		return nil
	}
	text := *raw
	if strings.TrimSpace(text) == "" {
		return nil
	}

	var value any
	if json.Unmarshal([]byte(text), &value) != nil {
		return nil
	}
	return value
}
|
|
|
|
func getMapString(values map[string]any, key string) string {
|
|
return stringify(values[key])
|
|
}
|
|
|
|
func getMapInt(values map[string]any, key string) int {
|
|
if value, ok := toInt(values[key]); ok {
|
|
return value
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// getMapAny returns the value stored under key, or nil when the key is
// absent. Indexing a map yields the zero value for a missing key, which for
// an `any` element is nil, so no explicit presence check is needed.
func getMapAny(values map[string]any, key string) any {
	return values[key]
}
|
|
|
|
func mapAnyInt(values map[string]any, key string) int {
|
|
if value, ok := toInt(values[key]); ok {
|
|
return value
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// stringify renders a loosely typed value as text.
//
// Strings are returned unchanged, fmt.Stringer values via String(), and int
// and int64 in decimal. A float64 (the type encoding/json uses for every
// JSON number) is written without a decimal point when it is a whole number
// that fits in int64, and otherwise in its shortest exact decimal form, so
// fractional values such as 1.5 are no longer truncated to "1". Every other
// type, including nil and bool, yields "".
func stringify(value any) string {
	switch typed := value.(type) {
	case string:
		return typed
	case fmt.Stringer:
		return typed.String()
	case float64:
		// Only take the integer path when the conversion is exact and in
		// range; converting NaN, Inf or huge values to int64 is not
		// well defined.
		if typed == math.Trunc(typed) && math.Abs(typed) < 1<<63 {
			return strconv.FormatInt(int64(typed), 10)
		}
		return strconv.FormatFloat(typed, 'f', -1, 64)
	case int:
		return strconv.Itoa(typed)
	case int64:
		return strconv.FormatInt(typed, 10)
	default:
		return ""
	}
}
|
|
|
|
// pruneNilMap returns a new map holding every entry of values except those
// whose value is nil or an empty string. The input map is not modified, and
// the result is non-nil even for a nil input.
func pruneNilMap(values map[string]any) map[string]any {
	kept := make(map[string]any, len(values))
	for name, entry := range values {
		switch typed := entry.(type) {
		case nil:
			continue
		case string:
			if typed == "" {
				continue
			}
		}
		kept[name] = entry
	}
	return kept
}
|
|
|
|
// toInt converts a loosely typed value to int and reports whether the
// conversion succeeded.
//
// int is returned as is, int64 and float64 are converted with int(...) (so a
// float loses its fractional part), and a string is trimmed and parsed with
// strconv.Atoi. Any other type, including nil, yields (0, false).
func toInt(value any) (int, bool) {
	if n, ok := value.(int); ok {
		return n, true
	}
	if n, ok := value.(int64); ok {
		return int(n), true
	}
	if f, ok := value.(float64); ok {
		return int(f), true
	}
	if s, ok := value.(string); ok {
		n, err := strconv.Atoi(strings.TrimSpace(s))
		return n, err == nil
	}
	return 0, false
}
|
|
|
|
// toFloat64 converts a loosely typed value to float64 and reports whether the
// conversion succeeded.
//
// float64 is returned as is, int and int64 are widened, and a string is
// trimmed and parsed with strconv.ParseFloat. Any other type, including nil,
// yields (0, false).
func toFloat64(value any) (float64, bool) {
	var (
		result float64
		ok     = true
	)
	switch v := value.(type) {
	case float64:
		result = v
	case int:
		result = float64(v)
	case int64:
		result = float64(v)
	case string:
		parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
		result, ok = parsed, err == nil
	default:
		ok = false
	}
	return result, ok
}
|
|
|
|
// containsInt reports whether target occurs in values. A nil or empty slice
// never contains anything.
func containsInt(values []int, target int) bool {
	for i := 0; i < len(values); i++ {
		if values[i] == target {
			return true
		}
	}
	return false
}
|