// WebRtsp/main.go
// 2025-12-09 17:31:21 +08:00
//
// 958 lines
// 26 KiB
// Go

package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/exec"
"os/signal"
"runtime"
"runtime/debug"
"strings"
"sync"
"syscall"
"time"
"github.com/patrickmn/go-cache"
"github.com/pion/webrtc/v3"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/time/rate"
"gopkg.in/yaml.v3"
)
// StreamConfig identifies the source stream a client asks the server to relay.
type StreamConfig struct {
	URL  string `json:"url"`  // source URL; passed to ffprobe/ffmpeg as the input (-i)
	Name string `json:"name"` // client-chosen label; used in the track stream ID, the cache key and metric labels
}
// OfferRequest is the JSON body POSTed to /offer.
type OfferRequest struct {
	Offer        webrtc.SessionDescription `json:"offer"`        // the client's SDP offer, applied with SetRemoteDescription
	StreamConfig StreamConfig              `json:"streamConfig"` // which source to stream back to the client
}
// Config is the application configuration, decoded from config.yaml by
// loadConfig and checked by validateConfig. YAML keys come from the struct tags.
// NOTE(review): the time.Duration fields presumably rely on yaml.v3 parsing Go
// duration strings such as "30s" — confirm against the deployed config.
type Config struct {
	// Server configures the HTTP listener; main copies these into http.Server.
	Server struct {
		Port           string        `yaml:"port"`             // used verbatim as http.Server.Addr
		ReadTimeout    time.Duration `yaml:"read_timeout"`     // http.Server.ReadTimeout
		WriteTimeout   time.Duration `yaml:"write_timeout"`    // http.Server.WriteTimeout
		MaxHeaderBytes int           `yaml:"max_header_bytes"` // http.Server.MaxHeaderBytes
	} `yaml:"server"`
	// FFmpeg holds transcoding settings. MaxConcurrentStreams caps concurrent
	// /offer streams and FrameRate drives the sample pacing. The other fields are
	// at most validated or logged.
	// NOTE(review): the FFmpeg command line in main hard-codes its encoder
	// arguments instead of using these fields.
	FFmpeg struct {
		MaxConcurrentStreams int    `yaml:"max_concurrent_streams"`
		Preset               string `yaml:"preset"`
		Bitrate              string `yaml:"bitrate"`
		Maxrate              string `yaml:"maxrate"`
		BufferSize           string `yaml:"buffer_size"`
		Keyint               int    `yaml:"keyint"`
		MinKeyint            int    `yaml:"min_keyint"`
		Threads              int    `yaml:"threads"`
		FrameRate            int    `yaml:"frame_rate"`
		GopSize              int    `yaml:"gop_size"`
		ScaleWidth           int    `yaml:"scale_width"`
		ScaleHeight          int    `yaml:"scale_height"`
	} `yaml:"ffmpeg"`
	// Retry settings. validateConfig checks them, but no visible code uses them.
	Retry struct {
		MaxRetries    int           `yaml:"max_retries"`
		RetryInterval time.Duration `yaml:"retry_interval"`
	} `yaml:"retry"`
	// Metrics toggles the Prometheus endpoint served at Path.
	Metrics struct {
		Enabled bool   `yaml:"enabled"`
		Path    string `yaml:"path"`
	} `yaml:"metrics"`
	// Cache controls appCache. ExpireTime is the default TTL; main uses
	// 2*ExpireTime as the cleanup interval.
	// NOTE(review): MaxSize is not referenced by the visible code.
	Cache struct {
		Enabled    bool          `yaml:"enabled"`
		MaxSize    int           `yaml:"max_size"`
		ExpireTime time.Duration `yaml:"expire_time"`
	} `yaml:"cache"`
}
// Prometheus metrics. initMetrics registers them with the default registry, and
// main calls initMetrics only when metrics are enabled in the config.
var (
	// activeStreamsGauge is the number of /offer streams currently holding a
	// concurrency slot.
	activeStreamsGauge = prometheus.NewGauge(prometheus.GaugeOpts{
		Name: "webrtc_active_streams",
		Help: "Number of active WebRTC streams",
	})
	// ffmpegErrorsCounter counts FFmpeg start failures, labelled by stream name.
	ffmpegErrorsCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "ffmpeg_errors_total",
		Help: "Total number of FFmpeg errors",
	}, []string{"stream_name"})
	// streamLatencyHistogram: despite its name, the /offer handler records each
	// FFmpeg streaming goroutine's total run time here, labelled by stream name.
	streamLatencyHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Name:    "stream_processing_latency_seconds",
		Help:    "Latency of stream processing in seconds",
		Buckets: prometheus.LinearBuckets(0, 0.1, 10), // 10 buckets: 0, 0.1, ..., 0.9 s
	}, []string{"stream_name"})
	// writeTimeoutCounter: NOTE(review) never incremented by the visible code.
	writeTimeoutCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "write_timeouts_total",
		Help: "Total number of write timeouts",
	}, []string{"stream_name"})
	// streamProcessingDuration: NOTE(review) never observed by the visible code.
	streamProcessingDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "stream_processing_duration_seconds",
			Help:    "Time spent processing each stream",
			Buckets: prometheus.ExponentialBuckets(0.1, 2, 10),
		},
		[]string{"stream_name"},
	)
	// requestDurationHistogram is filled by metricsMiddleware and labelled with
	// the request path, method and response status code.
	requestDurationHistogram = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "http_request_duration_seconds",
			Help:    "HTTP request duration in seconds",
			Buckets: prometheus.DefBuckets,
		},
		[]string{"path", "method", "status"},
	)
	// memoryUsageGauge is set to runtime.MemStats.Alloc by metricsMiddleware
	// after each request.
	memoryUsageGauge = prometheus.NewGauge(prometheus.GaugeOpts{
		Name: "app_memory_usage_bytes",
		Help: "Current memory usage in bytes",
	})
)
// appCache is nil unless caching is enabled in the config. main creates it with
// cfg.Cache.ExpireTime as the default TTL. The /offer handler keys it by
// "<url>_<name>" to reject a second request for a stream already in progress.
var (
	appCache *cache.Cache
)
// StreamError describes a stream failure. It carries an optional status code
// and an optional underlying cause.
type StreamError struct {
	Code    int    // status code associated with the failure (not used by Error)
	Message string // human-readable description
	Err     error  // optional underlying cause
}

// Error returns Message alone, or "Message: <cause>" when Err is set.
func (e *StreamError) Error() string {
	if e.Err == nil {
		return e.Message
	}
	return e.Message + ": " + fmt.Sprint(e.Err)
}
// loadConfig reads config.yaml from the current working directory, validates it
// with validateConfig, logs the effective settings and returns it.
//
// Errors wrap their cause with %w, so callers can inspect it with errors.Is or
// errors.As (for example errors.Is(err, fs.ErrNotExist) for a missing file).
func loadConfig() (*Config, error) {
	f, err := os.Open("config.yaml")
	if err != nil {
		return nil, fmt.Errorf("error opening config file: %w", err)
	}
	defer f.Close()

	var cfg Config
	if err := yaml.NewDecoder(f).Decode(&cfg); err != nil {
		return nil, fmt.Errorf("error decoding config file: %w", err)
	}
	if err := validateConfig(&cfg); err != nil {
		return nil, fmt.Errorf("invalid configuration: %w", err)
	}

	// Log the effective configuration so misconfigurations are visible at startup.
	log.Printf("Server Configuration:")
	log.Printf(" Port: %s", cfg.Server.Port)
	log.Printf(" Read Timeout: %v", cfg.Server.ReadTimeout)
	log.Printf(" Write Timeout: %v", cfg.Server.WriteTimeout)
	log.Printf(" Max Header Bytes: %d", cfg.Server.MaxHeaderBytes)
	log.Printf("FFmpeg Configuration:")
	log.Printf(" Max Concurrent Streams: %d", cfg.FFmpeg.MaxConcurrentStreams)
	log.Printf(" Preset: %s", cfg.FFmpeg.Preset)
	log.Printf(" Bitrate: %s", cfg.FFmpeg.Bitrate)
	log.Printf(" Frame Rate: %d", cfg.FFmpeg.FrameRate)
	log.Printf("Metrics: enabled=%v path=%q", cfg.Metrics.Enabled, cfg.Metrics.Path)
	log.Printf("Cache: enabled=%v expire=%v", cfg.Cache.Enabled, cfg.Cache.ExpireTime)
	return &cfg, nil
}
// validateConfig checks the settings the server relies on. It returns an error
// describing the first problem found, or nil when the configuration is usable.
func validateConfig(cfg *Config) error {
	// Server settings.
	if cfg.Server.Port == "" {
		return fmt.Errorf("server port is required")
	}
	// Port is used verbatim as http.Server.Addr, which must be in host:port form
	// (":8080", "0.0.0.0:8080"). A bare "8080" would only fail later, inside
	// ListenAndServe, with "missing port in address".
	if !strings.Contains(cfg.Server.Port, ":") {
		return fmt.Errorf("server port must be a listen address such as \":8080\", got %q", cfg.Server.Port)
	}
	if cfg.Server.ReadTimeout <= 0 {
		return fmt.Errorf("server read timeout must be positive")
	}
	if cfg.Server.WriteTimeout <= 0 {
		return fmt.Errorf("server write timeout must be positive")
	}

	// FFmpeg settings. FrameRate is a divisor when computing the frame interval.
	if cfg.FFmpeg.MaxConcurrentStreams <= 0 {
		return fmt.Errorf("max concurrent streams must be positive")
	}
	if cfg.FFmpeg.FrameRate <= 0 {
		return fmt.Errorf("frame rate must be positive")
	}
	if cfg.FFmpeg.Preset == "" {
		return fmt.Errorf("ffmpeg preset is required")
	}
	if cfg.FFmpeg.Bitrate == "" {
		return fmt.Errorf("ffmpeg bitrate is required")
	}
	if cfg.FFmpeg.Maxrate == "" {
		return fmt.Errorf("ffmpeg maxrate is required")
	}
	if cfg.FFmpeg.BufferSize == "" {
		return fmt.Errorf("ffmpeg buffer size is required")
	}
	if cfg.FFmpeg.ScaleWidth <= 0 {
		return fmt.Errorf("ffmpeg scale width must be positive")
	}
	if cfg.FFmpeg.ScaleHeight <= 0 {
		return fmt.Errorf("ffmpeg scale height must be positive")
	}

	// Retry settings.
	if cfg.Retry.MaxRetries < 0 {
		return fmt.Errorf("retry max retries must be non-negative")
	}
	if cfg.Retry.RetryInterval <= 0 {
		return fmt.Errorf("retry interval must be positive")
	}

	// Metrics: the path is registered as a ServeMux pattern. An empty pattern
	// panics, and one without a leading "/" is treated as a host pattern, so the
	// endpoint would be unreachable or invalid.
	if cfg.Metrics.Enabled && !strings.HasPrefix(cfg.Metrics.Path, "/") {
		return fmt.Errorf("metrics path must start with \"/\" when metrics are enabled, got %q", cfg.Metrics.Path)
	}
	return nil
}
// initMetrics registers every application collector with the default Prometheus
// registry, in declaration order. It panics if any collector is already registered.
func initMetrics() {
	prometheus.MustRegister(
		activeStreamsGauge,
		ffmpegErrorsCounter,
		streamLatencyHistogram,
		writeTimeoutCounter,
		streamProcessingDuration,
		requestDurationHistogram,
		memoryUsageGauge,
	)
}
// recoveryMiddleware turns a panic inside next into a logged stack trace and a
// 500 response, so one faulty request cannot take down the server.
//
// A panic with http.ErrAbortHandler is re-raised. net/http uses that sentinel to
// abort a response on purpose (and suppresses the stack trace for it), so
// turning it into a 500 would defeat its purpose. Only panics on the request
// goroutine are caught; goroutines started by handlers need their own recovery.
func recoveryMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			err := recover()
			if err == nil {
				return
			}
			if err == http.ErrAbortHandler {
				panic(err)
			}
			log.Printf("Panic recovered: %v\nStack trace: %s", err, debug.Stack())
			http.Error(w, "Internal server error", http.StatusInternalServerError)
		}()
		next.ServeHTTP(w, r)
	})
}
// LogEntry is a single structured log record, written as one JSON object per line.
type LogEntry struct {
	Level     string    `json:"level"`
	Timestamp time.Time `json:"timestamp"`
	Message   string    `json:"message"`
	StreamID  string    `json:"stream_id,omitempty"`
	Error     string    `json:"error,omitempty"`
}

// writeLogEntry formats the message and writes it to out as a JSON line tagged
// with level and the current time. Logging is best-effort, so encoding errors
// are ignored.
func writeLogEntry(out io.Writer, level, format string, v ...interface{}) {
	_ = json.NewEncoder(out).Encode(LogEntry{
		Level:     level,
		Timestamp: time.Now(),
		Message:   fmt.Sprintf(format, v...),
	})
}

// logError writes an ERROR-level JSON log line to standard error.
func logError(format string, v ...interface{}) {
	writeLogEntry(os.Stderr, "ERROR", format, v...)
}

// logInfo writes an INFO-level JSON log line to standard output.
func logInfo(format string, v ...interface{}) {
	writeLogEntry(os.Stdout, "INFO", format, v...)
}
// securityHeaders sets a fixed set of browser-hardening response headers and
// then delegates to next.
func securityHeaders(next http.Handler) http.Handler {
	hardening := [...]struct{ name, value string }{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		for _, h := range hardening {
			header.Set(h.name, h.value)
		}
		next.ServeHTTP(w, r)
	})
}
// gracefulShutdown makes server shut down gracefully on SIGINT or SIGTERM,
// giving in-flight requests up to timeout to finish. It returns immediately and
// handles the signal on a background goroutine.
//
// The signal handler is installed before the function returns. Calling
// signal.Notify inside the goroutine instead would leave a window in which an
// early signal hits the default handler and kills the process outright.
func gracefulShutdown(server *http.Server, timeout time.Duration) {
	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)

	go func() {
		<-signalChan
		// Restore default handling so a second signal terminates the process
		// immediately instead of being silently dropped.
		signal.Stop(signalChan)
		log.Println("Shutdown signal received")

		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()
		server.SetKeepAlivesEnabled(false)
		if err := server.Shutdown(ctx); err != nil {
			log.Printf("Could not gracefully shutdown the server: %v\n", err)
		}
	}()
}
// corsMiddleware applies the CORS policy. Access-Control-Allow-Origin echoes the
// request's Origin only when it is in the allowlist. Allowed methods, headers,
// credentials and the preflight cache lifetime are sent on every response.
// OPTIONS (preflight) requests are answered with 200 and never reach next.
func corsMiddleware(next http.Handler) http.Handler {
	// Allowed origins, built once rather than on every request.
	allowedOrigins := map[string]bool{
		"http://127.0.0.1:5500": true,
		"http://localhost:5500": true,
		// Add other origins as needed.
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		// The response depends on the Origin request header. Without Vary, a
		// shared cache could serve one origin's CORS headers to another.
		header.Add("Vary", "Origin")
		if origin := r.Header.Get("Origin"); allowedOrigins[origin] {
			header.Set("Access-Control-Allow-Origin", origin)
		}
		header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
		header.Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
		header.Set("Access-Control-Allow-Credentials", "true")
		// Let browsers cache the preflight result for 24 hours.
		header.Set("Access-Control-Max-Age", "86400")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}
func main() {
log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
mediaEngine := webrtc.MediaEngine{}
if err := mediaEngine.RegisterDefaultCodecs(); err != nil {
log.Fatal("Failed to register default codecs:", err)
}
api := webrtc.NewAPI(webrtc.WithMediaEngine(&mediaEngine))
// 加载配置
cfg, err := loadConfig()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 初始化指标
if cfg.Metrics.Enabled {
initMetrics()
}
// 初始化缓存
if cfg.Cache.Enabled {
appCache = cache.New(cfg.Cache.ExpireTime, 2*cfg.Cache.ExpireTime)
log.Println("Cache initialized")
}
// 修改资源限制部分
var (
activeStreams = make(chan struct{}, cfg.FFmpeg.MaxConcurrentStreams)
streamsMutex sync.RWMutex
activeStreamCount int
)
// 添加 Prometheus metrics endpoint
if cfg.Metrics.Enabled {
http.Handle(cfg.Metrics.Path, promhttp.Handler())
}
// 创建用于优雅关闭的信号处理
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
<-signalChan
log.Println("Received shutdown signal, cleaning up...")
cancel()
}()
// 创建 mux 并应用中间件
mux := http.NewServeMux()
// 注册路由
mux.HandleFunc("/offer", func(w http.ResponseWriter, r *http.Request) {
requestID := fmt.Sprintf("%d", time.Now().UnixNano())
log.Printf("[%s] Received new offer request", requestID)
// 在处理新请求前检是否达到最大并发数
select {
case activeStreams <- struct{}{}: // 获取令牌
streamsMutex.Lock()
activeStreamCount++
activeStreamsGauge.Set(float64(activeStreamCount))
streamsMutex.Unlock()
default:
http.Error(w, "Max concurrent streams reached", http.StatusServiceUnavailable)
return
}
log.Println("Received /offer request")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
const (
maxRequestBodySize = 1 << 20 // 1MB
)
// 在处理请求时添加大小限制
body, err := io.ReadAll(io.LimitReader(r.Body, maxRequestBodySize))
if err != nil {
log.Println("Failed to read request body:", err)
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
defer r.Body.Close()
var offerReq OfferRequest
if err := json.Unmarshal(body, &offerReq); err != nil {
log.Println("Failed to parse request:", err)
http.Error(w, "Invalid request format", http.StatusBadRequest)
return
}
// 添加流配置验证
if err := validateStreamConfig(offerReq.StreamConfig); err != nil {
log.Printf("Invalid stream config: %v", err)
http.Error(w, fmt.Sprintf("Invalid stream config: %v", err), http.StatusBadRequest)
return
}
// 使用缓存储或检某些数据(例如,已处理的流)
if cfg.Cache.Enabled {
cacheKey := fmt.Sprintf("%s_%s", offerReq.StreamConfig.URL, offerReq.StreamConfig.Name)
if _, found := appCache.Get(cacheKey); found {
log.Printf("Stream %s is already being processed\n", offerReq.StreamConfig.Name)
http.Error(w, "Stream already being processed", http.StatusConflict)
<-activeStreams // 释放令牌
return
}
// 使用新的缓存键
appCache.Set(cacheKey, true, cache.DefaultExpiration)
}
// 在处理请求前先测试 RTSP 流是否可访问
if err := testRTSPStream(offerReq.StreamConfig.URL); err != nil {
log.Printf("RTSP stream test failed: %v", err)
http.Error(w, fmt.Sprintf("RTSP stream test failed: %v", err), http.StatusBadRequest)
return
}
configuration := webrtc.Configuration{}
peerConnection, err := api.NewPeerConnection(configuration)
if err != nil {
log.Println("Failed to create peer connection:", err)
http.Error(w, "Failed to create peer connection", http.StatusInternalServerError)
if cfg.Cache.Enabled {
appCache.Delete(offerReq.StreamConfig.Name) // 处理失败,移除缓存
}
return
}
log.Printf("[%s] Peer connection created successfully", requestID)
var wg sync.WaitGroup
done := make(chan struct{})
cmdDone := make(chan struct{})
var cmd *exec.Cmd
var cmdMutex sync.Mutex
var cleanupOnce sync.Once
var cleanupMutex sync.Mutex
isCleanedUp := false
cleanup := func() {
cleanupOnce.Do(func() {
cleanupMutex.Lock()
if isCleanedUp {
cleanupMutex.Unlock()
return
}
isCleanedUp = true
cleanupMutex.Unlock()
log.Printf("Cleanup triggered for stream: %s\n", offerReq.StreamConfig.URL)
// 使用相同的缓存键进行删除
if cfg.Cache.Enabled {
cacheKey := fmt.Sprintf("%s_%s", offerReq.StreamConfig.URL, offerReq.StreamConfig.Name)
appCache.Delete(cacheKey)
}
select {
case <-done:
// channel 已经关闭
default:
close(done)
}
// 等待 FFmpeg 处理完成
select {
case <-cmdDone:
case <-time.After(5 * time.Second):
log.Println("FFmpeg cleanup timed out")
}
wg.Wait()
if err := peerConnection.Close(); err != nil {
log.Printf("Failed to close peer connection: %v", err)
}
if cfg.Cache.Enabled {
appCache.Delete(offerReq.StreamConfig.Name)
}
// 确 FFmpeg 进程被终止
cmdMutex.Lock()
if cmd != nil && cmd.Process != nil {
if err := cmd.Process.Kill(); err != nil {
if !strings.Contains(err.Error(), "process already finished") {
log.Printf("Error killing FFmpeg process: %v", err)
}
}
// 等待进程完全退出
if err := cmd.Wait(); err != nil {
log.Printf("Error waiting for FFmpeg process to exit: %v", err)
}
}
cmdMutex.Unlock()
log.Printf("Cleanup completed for stream: %s\n", offerReq.StreamConfig.URL)
})
}
// 修改这里:使用符合规范的 streamID
streamID := fmt.Sprintf("stream_%s", strings.ReplaceAll(offerReq.StreamConfig.Name, " ", "_"))
videoTrack, err := webrtc.NewTrackLocalStaticSample(
webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264},
"video",
streamID,
)
if err != nil {
log.Println("Failed to create video track:", err)
cleanup()
http.Error(w, "Failed to create video track", http.StatusInternalServerError)
return
}
rtpSender, err := peerConnection.AddTrack(videoTrack)
if err != nil {
log.Println("Failed to add video track:", err)
cleanup()
http.Error(w, "Failed to add video track", http.StatusInternalServerError)
return
}
wg.Add(1)
go func() {
defer wg.Done()
rtcpBuf := make([]byte, 1500)
for {
select {
case <-done:
return
default:
if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil {
if rtcpErr != io.EOF {
log.Printf("rtcp error: %v", rtcpErr)
}
return
}
}
}
}()
peerConnection.OnConnectionStateChange(func(s webrtc.PeerConnectionState) {
log.Printf("Peer Connection State has changed to %s for stream: %s\n", s.String(), offerReq.StreamConfig.URL)
switch s {
case webrtc.PeerConnectionStateFailed,
webrtc.PeerConnectionStateClosed,
webrtc.PeerConnectionStateDisconnected:
log.Printf("Peer Connection %s, cleaning up...\n", s.String())
cleanup()
}
})
peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
log.Printf("ICE Connection State has changed: %s\n", connectionState.String())
})
log.Printf("[%s] Setting remote description", requestID)
if err = peerConnection.SetRemoteDescription(offerReq.Offer); err != nil {
log.Printf("[%s] Failed to set remote description: %v", requestID, err)
cleanup()
http.Error(w, fmt.Sprintf("Failed to set remote description: %v", err), http.StatusInternalServerError)
return
}
log.Printf("[%s] Remote description set successfully", requestID)
log.Printf("[%s] Creating answer", requestID)
answer, err := peerConnection.CreateAnswer(nil)
if err != nil {
log.Printf("[%s] Failed to create answer: %v", requestID, err)
cleanup()
http.Error(w, fmt.Sprintf("Failed to create answer: %v", err), http.StatusInternalServerError)
return
}
log.Printf("[%s] Answer created successfully", requestID)
log.Printf("[%s] Setting local description", requestID)
if err = peerConnection.SetLocalDescription(answer); err != nil {
log.Printf("[%s] Failed to set local description: %v", requestID, err)
cleanup()
http.Error(w, fmt.Sprintf("Failed to set local description: %v", err), http.StatusInternalServerError)
return
}
log.Printf("[%s] Local description set successfully", requestID)
gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
select {
case <-gatherComplete:
log.Println("ICE gathering completed")
case <-time.After(3 * time.Second):
log.Println("ICE gathering timed out")
cleanup()
http.Error(w, "ICE gathering timed out", http.StatusInternalServerError)
return
}
response := peerConnection.LocalDescription()
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
log.Println("Failed to encode response:", err)
cleanup()
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
wg.Add(1)
go func() {
defer wg.Done()
defer func() {
select {
case <-cmdDone:
default:
close(cmdDone)
}
// 释放流资源
<-activeStreams
streamsMutex.Lock()
activeStreamCount--
activeStreamsGauge.Set(float64(activeStreamCount))
streamsMutex.Unlock()
}()
// 使用带缓冲的管道来控制数据流
bufferPool := sync.Pool{
New: func() interface{} {
return make([]byte, 1024*1024)
},
}
startTime := time.Now()
defer func() {
duration := time.Since(startTime).Seconds()
streamLatencyHistogram.WithLabelValues(offerReq.StreamConfig.Name).Observe(duration)
}()
cmdMutex.Lock()
cmd = exec.CommandContext(ctx, "ffmpeg",
"-i", offerReq.StreamConfig.URL,
"-c:v", "libx264",
"-preset", "ultrafast",
"-tune", "zerolatency",
"-profile:v", "baseline",
"-level", "3.0",
"-x264-params", "keyint=30:min-keyint=30:scenecut=0:bframes=0",
"-b:v", "1000k",
"-maxrate", "1500k",
"-bufsize", "2000k",
"-r", "25",
"-g", "30",
"-threads", "4",
"-f", "h264",
"-pix_fmt", "yuv420p",
"-vf", "scale=640:480",
"-movflags", "+faststart",
"-")
cmdMutex.Unlock()
// 添加错误输出捕获
var errBuf bytes.Buffer
cmd.Stderr = &errBuf
ffmpegStdout, err := cmd.StdoutPipe()
if err != nil {
log.Printf("Failed to create stdout pipe: %v", err)
return
}
if err := cmd.Start(); err != nil {
ffmpegErrorsCounter.WithLabelValues(offerReq.StreamConfig.Name).Inc()
log.Printf("Failed to start FFmpeg: %v\nFFmpeg error: %s", err, errBuf.String())
return
}
reader := bufio.NewReaderSize(ffmpegStdout, 8*1024*1024) // 使用8MB缓冲区
var lastFrameTime time.Time
// 在 FFmpeg 处理部分添加帧控制
targetFPS := cfg.FFmpeg.FrameRate
frameInterval := time.Second / time.Duration(targetFPS)
// 在视频处理循环中添加帧率控制
for {
select {
case <-ctx.Done():
log.Printf("[%s] Context cancelled", streamID)
return
case <-done:
log.Printf("[%s] Done signal received", streamID)
return
default:
// 控制帧率
elapsed := time.Since(lastFrameTime)
if elapsed < frameInterval {
time.Sleep(frameInterval - elapsed)
}
lastFrameTime = time.Now()
buffer := bufferPool.Get().([]byte)
n, err := reader.Read(buffer)
if err != nil {
bufferPool.Put(buffer)
if err != io.EOF {
log.Printf("[%s] Error reading from FFmpeg: %v", streamID, err)
}
return
}
writeCtx, writeCancel := context.WithTimeout(ctx, cfg.Server.WriteTimeout)
select {
case <-writeCtx.Done():
writeCancel()
bufferPool.Put(buffer)
log.Printf("[%s] Write timeout", streamID)
return
default:
err = videoTrack.WriteSample(media.Sample{
Data: buffer[:n],
Duration: frameInterval,
})
writeCancel()
bufferPool.Put(buffer)
if err != nil {
log.Printf("[%s] Error writing video sample: %v", streamID, err)
return
}
}
}
}
// 在函数结束时清理 FFmpeg 进程
defer func() {
cmdMutex.Lock()
if cmd != nil && cmd.Process != nil {
if err := cmd.Process.Kill(); err != nil {
if !strings.Contains(err.Error(), "process already finished") {
log.Printf("Error killing process: %v", err)
}
}
cmd.Wait()
}
cmdMutex.Unlock()
}()
}()
})
type HealthStatus struct {
Status string `json:"status"`
ActiveStreams int `json:"active_streams"`
MaxStreams int `json:"max_streams"`
MemStats runtime.MemStats `json:"mem_stats"`
FFmpegStatus string `json:"ffmpeg_status"`
UptimeSeconds int64 `json:"uptime_seconds"`
Version string `json:"version"`
}
var startTime = time.Now()
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
streamsMutex.RLock()
health := HealthStatus{
Status: "healthy",
ActiveStreams: activeStreamCount,
MaxStreams: cfg.FFmpeg.MaxConcurrentStreams,
UptimeSeconds: int64(time.Since(startTime).Seconds()),
Version: "1.0.0", // 添加版本信息
}
streamsMutex.RUnlock()
// 检查 FFmpeg 是否可用
cmd := exec.Command("ffmpeg", "-version")
if err := cmd.Run(); err != nil {
health.FFmpegStatus = "unavailable"
health.Status = "degraded"
} else {
health.FFmpegStatus = "available"
}
runtime.ReadMemStats(&health.MemStats)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(health)
})
// 修改中间件链
handler := chainMiddleware(
mux,
recoveryMiddleware,
corsMiddleware, // 添加 CORS 中间件
rateLimitMiddleware,
metricsMiddleware,
securityHeaders,
)
server := &http.Server{
Addr: cfg.Server.Port,
Handler: handler,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
MaxHeaderBytes: cfg.Server.MaxHeaderBytes,
}
go func() {
<-ctx.Done()
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
server.Shutdown(shutdownCtx)
}()
log.Printf("WebRTC Server running at %s\n", cfg.Server.Port)
if err := server.ListenAndServe(); err != http.ErrServerClosed {
log.Fatal("Failed to start HTTP server:", err)
}
}
// validateStreamConfig checks a client-supplied stream description. Name and URL
// are required, and URL must be an rtsp:// or rtsps:// URL with a host.
//
// The URL comes from an untrusted request body and is passed verbatim to
// ffprobe/ffmpeg as an input. Limiting it to RTSP stops clients from making the
// server read local files (file:, bare paths) or use other FFmpeg protocols
// (concat:, pipe:, http:, ...).
func validateStreamConfig(cfg StreamConfig) error {
	if cfg.Name == "" {
		return fmt.Errorf("stream name is required")
	}
	if cfg.URL == "" {
		return fmt.Errorf("stream URL is required")
	}
	u, err := url.Parse(cfg.URL)
	if err != nil {
		return fmt.Errorf("invalid stream URL: %w", err)
	}
	// url.Parse lowercases the scheme, so a direct comparison is enough.
	switch u.Scheme {
	case "rtsp", "rtsps":
	default:
		return fmt.Errorf("unsupported stream URL scheme %q: only rtsp and rtsps are allowed", u.Scheme)
	}
	if u.Host == "" {
		return fmt.Errorf("stream URL must include a host")
	}
	return nil
}
// rateLimitMiddleware rejects requests with 429 Too Many Requests once a
// server-wide token bucket is exhausted. The bucket allows 100 requests per
// second sustained, with bursts of up to 100. It is shared by all clients; it
// is not a per-IP limit.
func rateLimitMiddleware(next http.Handler) http.Handler {
	// rate.Limit is measured in events per second. rate.Every(time.Second) would
	// refill only one token per second, i.e. 1 req/s sustained instead of the
	// intended 100 req/s.
	limiter := rate.NewLimiter(rate.Limit(100), 100)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !limiter.Allow() {
			http.Error(w, "Too many requests", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// testRTSPStream uses ffprobe to check that streamURL is reachable and has a
// video stream. It logs the probe output on success and returns an error
// (including ffprobe's output) on failure.
//
// The probe is bounded by a timeout. An unresponsive source (e.g. a server that
// accepts the connection but never answers) would otherwise block the calling
// request indefinitely.
func testRTSPStream(streamURL string) error {
	const probeTimeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), probeTimeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, "ffprobe",
		"-v", "error",
		"-select_streams", "v:0", // only the first video stream
		"-show_entries", "stream=width,height,codec_name,r_frame_rate",
		"-of", "json",
		"-i", streamURL)
	output, err := cmd.CombinedOutput()
	if err != nil {
		if ctx.Err() == context.DeadlineExceeded {
			return fmt.Errorf("failed to connect to RTSP stream: timed out after %v", probeTimeout)
		}
		return fmt.Errorf("failed to connect to RTSP stream: %w, output: %s", err, string(output))
	}
	log.Printf("RTSP stream info: %s", string(output))
	return nil
}
// chainMiddleware wraps handler with each middleware in argument order. The
// first middleware becomes the innermost wrapper and the last one the
// outermost, so the last middleware sees each request first.
func chainMiddleware(handler http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
	wrapped := handler
	for i := 0; i < len(middlewares); i++ {
		wrapped = middlewares[i](wrapped)
	}
	return wrapped
}
// metricsMiddleware times each request and records it in
// requestDurationHistogram under (path, method, status). It then refreshes
// memoryUsageGauge with the current heap allocation.
// NOTE(review): raw URL paths make the label set unbounded, and
// runtime.ReadMemStats on every request is comparatively expensive.
func metricsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()

		// Wrap the writer to capture the status code. 200 is assumed when the
		// handler never calls WriteHeader.
		recorder := &responseWriter{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(recorder, r)

		elapsed := time.Since(began).Seconds()
		statusLabel := fmt.Sprintf("%d", recorder.status)
		requestDurationHistogram.WithLabelValues(r.URL.Path, r.Method, statusLabel).Observe(elapsed)

		var stats runtime.MemStats
		runtime.ReadMemStats(&stats)
		memoryUsageGauge.Set(float64(stats.Alloc))
	})
}
// responseWriter wraps an http.ResponseWriter and records the status code the
// client actually receives, for metricsMiddleware. Initialise status to
// http.StatusOK, which net/http sends implicitly when a handler writes a body
// without calling WriteHeader.
type responseWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // a final status has been sent; net/http ignores later WriteHeader calls
}

// WriteHeader records code if it is the first final status, then forwards it.
// Later calls are not recorded because the client never sees them.
// Informational 1xx codes (other than 101) are forwarded but not recorded,
// since a final status still follows them.
func (rw *responseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.status = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b. The first Write implicitly sends the header, so any later
// WriteHeader call no longer changes the recorded status.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the underlying writer, so http.ResponseController can reach
// optional interfaces such as Flush.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}