完善自我更新逻辑

This commit is contained in:
BBIT-Kai
2025-12-30 17:54:33 +08:00
parent ec29d883bd
commit 62e8ecb7d6
12 changed files with 248 additions and 112 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+96 -36
View File
@@ -4,50 +4,85 @@ import (
"encoding/json" "encoding/json"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "runtime"
"sentinel/pkg/config"
"sentinel/pkg/log" "sentinel/pkg/log"
model2 "sentinel/pkg/model" model2 "sentinel/pkg/model"
"sentinel/pkg/utils" "sentinel/pkg/utils"
"strconv"
"syscall"
"time" "time"
) )
type BusinessService struct { type BusinessService struct {
mqtt *MQTTService mqtt *MQTTService
deviceID string deviceID string
project string deptId string
cmdTopic string cmdTopic string
deviceType string deviceType string
subscriptions map[string]struct{} // 记录已订阅 topic
} }
func NewBusinessService(m *MQTTService, project, deviceType, deviceID string) *BusinessService { func NewBusinessService(m *MQTTService, deviceID string) *BusinessService {
// 根据统一规则生成 topic
cmdTopic := project + "/cmd/" + deviceType + "/" + deviceID + "/#"
return &BusinessService{ return &BusinessService{
mqtt: m, mqtt: m,
project: project, deviceID: deviceID,
deviceID: deviceID, subscriptions: make(map[string]struct{}),
cmdTopic: cmdTopic,
deviceType: deviceType,
} }
} }
// SubscribeTopic 订阅指定 topic,并记录可取消
func (b *BusinessService) SubscribeTopic(topic string, qos byte) error {
if err := b.mqtt.Subscribe(topic, qos); err != nil {
return err
}
b.subscriptions[topic] = struct{}{}
return nil
}
func getInitTopic(deviceID string) string {
return "+/+/+/" + deviceID + "/#"
}
func (b *BusinessService) getOwnTopic(deviceID string) string {
return b.deptId + "/cmd/" + b.deviceType + "/" + deviceID + "/#"
}
func (b *BusinessService) Start() error { func (b *BusinessService) Start() error {
// 订阅 cmd topic if err := b.SubscribeTopic(getInitTopic(b.deviceID), 1); err != nil {
if err := b.mqtt.Subscribe(b.cmdTopic, 1); err != nil {
return err return err
} }
b.mqtt.SetMessageHandler(b.onMQTTMessage) b.mqtt.SetMessageHandler(b.onMQTTMessage)
// 第一次连接就发送状态信息 // 第一次连接就发送状态信息
b.SendStatusInfo() b.SendStatusInfo()
return nil return nil
} }
// UnsubscribeTopic 取消订阅指定 topic
func (b *BusinessService) UnsubscribeTopic(topic string) error {
token := b.mqtt.client.Unsubscribe(topic)
if token.Wait() && token.Error() != nil {
return token.Error()
}
delete(b.subscriptions, topic)
return nil
}
// UnsubscribeAll 取消所有已订阅 topic
func (b *BusinessService) UnsubscribeAll() {
for topic := range b.subscriptions {
_ = b.mqtt.client.Unsubscribe(topic)
delete(b.subscriptions, topic)
}
}
// 消息处理 // 消息处理
func (b *BusinessService) onMQTTMessage(topic string, payload []byte) { func (b *BusinessService) onMQTTMessage(topic string, payload []byte) {
model := model2.FromStringToMqttTopic(topic) model := model2.FromStringToMqttTopic(topic)
// 指令 // 指令
if model.Domain == "cmd" { if model.Domain == "cmd" && model.DeviceType == b.deviceType {
log.Println("收到指令:", model.Resource) log.Println("收到指令:", model.Resource)
switch model.Resource { switch model.Resource {
case "ping": case "ping":
@@ -61,14 +96,26 @@ func (b *BusinessService) onMQTTMessage(topic string, payload []byte) {
default: default:
log.Println("未知的命令:", model.Resource) log.Println("未知的命令:", model.Resource)
} }
} else if model.Domain == "status" && model.Resource == "receipt" {
b.deviceType = model.DeviceType
b.deptId = model.DeptId
// 取消订阅之前的初始化主题
if b.UnsubscribeTopic(getInitTopic(b.deviceID)) != nil {
log.Error("无法取消初始化主题")
return
}
// 新订阅属于自己的主题
if b.SubscribeTopic(b.getOwnTopic(b.deviceID), 1) != nil {
log.Error("无法定于属于自己的主题")
return
}
log.Println("设备初始化成功:所属项目:", model.DeptId, "\t设备类型:", model.DeviceType)
} }
} }
func (b *BusinessService) SendStatusInfo() { func (b *BusinessService) SendStatusInfo() {
info := map[string]interface{}{ info := map[string]interface{}{
"project": utils.PROJECT, "version": config.APP_VERSION,
"deviceType": utils.DEVICE_TPYE,
"version": utils.APP_VERSION,
"online": true, "online": true,
"ip": utils.GetLocalIP(), "ip": utils.GetLocalIP(),
"hostname": utils.GetHostname(), "hostname": utils.GetHostname(),
@@ -81,14 +128,15 @@ func (b *BusinessService) SendStatusInfo() {
} }
payload, _ := json.Marshal(info) payload, _ := json.Marshal(info)
topic := b.project + "/status/" + b.deviceType + "/" + b.deviceID + "/info" topic := "x/status/x/" + b.deviceID + "/info"
qos := byte(1) qos := byte(1)
retained := true retained := true
log.Println("发送消息:", topic)
if err := b.mqtt.Publish(topic, qos, retained, payload); err != nil { if err := b.mqtt.Publish(topic, qos, retained, payload); err != nil {
log.Println("[BUS] failed to send status info:", err) log.Println("发送状态信息出错:", err)
} else { } else {
log.Println("[BUS] status info sent:", string(payload)) log.Println("发送状态信息:", string(payload))
} }
} }
@@ -107,25 +155,37 @@ func (b *BusinessService) handleRestart() {
os.Exit(0) os.Exit(0)
} }
// 更新程序 // handleCheckUpdate 触发更新流程(主程序侧)
func (b *BusinessService) handleCheckUpdate() { func (b *BusinessService) handleCheckUpdate() {
exe, _ := os.Executable()
updaterPath := filepath.Join(filepath.Dir(exe), "updater") args := []string{
if _, err := os.Stat(updaterPath); os.IsNotExist(err) { "--version", strconv.Itoa(config.APP_VERSION),
if _, err2 := os.Stat(updaterPath + ".exe"); err2 == nil {
updaterPath = updaterPath + ".exe"
} else {
log.Println("[BUS] updater not found")
return
}
} }
cmd := exec.Command(updaterPath, "--target", exe)
cmd := exec.Command("./updater.exe", args...)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
// OS 级脱离父进程
switch runtime.GOOS {
case "windows":
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
log.Println("[BUS] failed to start updater:", err) log.Println("[BUS] failed to start updater:", err)
return return
} }
log.Println("[BUS] exiting main program for update")
log.Println(
"[BUS] updater started (pid=%d), exiting main program\n",
cmd.Process.Pid,
)
// 给 updater 留出启动窗口(尤其是 systemd / docker 环境)
time.Sleep(500 * time.Millisecond)
os.Exit(0) os.Exit(0)
} }
+21 -13
View File
@@ -2,41 +2,49 @@ package main
import ( import (
"fmt" "fmt"
"sentinel/pkg/utils"
"time" "time"
"sentinel/pkg/config"
"sentinel/pkg/device" "sentinel/pkg/device"
"sentinel/pkg/log" "sentinel/pkg/log"
) )
func main() { func main() {
deviceID := device.GetDeviceID() banner := `
log.Init(utils.Log_file_dic) // 初始化日志目录 ==========================================================================
log.Info("Device id: " + deviceID) // 第一次启动记录 _______ _______ _ __________________ _ _______ _
( ____ \( ____ \( ( /|\__ __/\__ __/( ( /|( ____ \( \
| (_____ | (__ | \ | | | | | | | \ | || (__ | |
(_____ )| __) | (\ \) | | | | | | (\ \) || __) | |
/\____) || (____/\| ) \ | | | ___) (___| ) \ || (____/\| (____/\
\_______)(_______/|/ )_) )_( \_______/|/ )_)(_______/(_______/
==========================================================================
`
broker := fmt.Sprintf("tls://%s:%d", utils.MQTT_HOST, utils.MQTT_PORT) fmt.Println(banner)
username := deviceID deviceID := device.GetDeviceID()
password := utils.PASSWORD log.Init(config.Log_file_dic) // 初始化日志目录
log.Info("Device id: " + deviceID) // 第一次启动记录
log.Println("版本号: ", config.APP_VERSION) // 第一次启动记录
var mqttSvc *MQTTService var mqttSvc *MQTTService
firstFail := true // 标记是否第一次失败 firstFail := true // 标记是否第一次失败
for { for {
mqttSvc = NewMQTTService(broker, username, username, password, 60) mqttSvc = NewMQTTService(config.MQTT_BROKER, deviceID, deviceID, config.PASSWORD, 60)
err := mqttSvc.Connect() err := mqttSvc.Connect()
if err != nil { if err != nil {
if firstFail { if firstFail {
log.Error("物联网服务连接失败,请先注册设备. DeviceID: " + deviceID + " ") log.Error("物联网服务连接失败,如未注册设备,请先注册: " + deviceID)
firstFail = false firstFail = false
} }
time.Sleep(5 * time.Second) // 5秒后重试 time.Sleep(3 * time.Second) // 5秒后重试
continue continue
} }
log.Info("物联网服务已启动")
break break
} }
defer mqttSvc.Close() defer mqttSvc.Close()
biz := NewBusinessService(mqttSvc, utils.PROJECT, utils.DEVICE_TPYE, deviceID) biz := NewBusinessService(mqttSvc, deviceID)
for { for {
// MQTT业务 // MQTT业务
err := biz.Start() err := biz.Start()
@@ -47,7 +55,7 @@ func main() {
continue continue
} }
// 个人业务 // 个人业务
test() //test()
break break
} }
+18
View File
@@ -0,0 +1,18 @@
package config
// 变动
// 常量
const (
// 版本号
APP_VERSION = 1
Log_file_dic = "./logs"
MQTT_BROKER = "tls://ai.ronsunny.cn:8093"
PASSWORD = "123456"
)
var (
// DeviceType string
// DeptId string
)
+4 -4
View File
@@ -6,7 +6,7 @@ import (
) )
type MqttTopic struct { type MqttTopic struct {
Project string DeptId string
Domain string Domain string
DeviceType string DeviceType string
DeviceID string DeviceID string
@@ -21,7 +21,7 @@ func FromStringToMqttTopic(topic string) *MqttTopic {
parts = append(parts, "") parts = append(parts, "")
} }
return &MqttTopic{ return &MqttTopic{
Project: parts[0], DeptId: parts[0],
Domain: parts[1], Domain: parts[1],
DeviceType: parts[2], DeviceType: parts[2],
DeviceID: parts[3], DeviceID: parts[3],
@@ -38,7 +38,7 @@ func (m *MqttTopic) ToString() string {
return s return s
} }
return strings.Join([]string{ return strings.Join([]string{
toVal(m.Project), toVal(m.DeptId),
toVal(m.Domain), toVal(m.Domain),
toVal(m.DeviceType), toVal(m.DeviceType),
toVal(m.DeviceID), toVal(m.DeviceID),
@@ -48,7 +48,7 @@ func (m *MqttTopic) ToString() string {
// 严格生成 topic,不允许 "+" 或空 // 严格生成 topic,不允许 "+" 或空
func (m *MqttTopic) Build() (string, error) { func (m *MqttTopic) Build() (string, error) {
parts := []string{m.Project, m.Domain, m.DeviceType, m.DeviceID, m.Resource} parts := []string{m.DeptId, m.Domain, m.DeviceType, m.DeviceID, m.Resource}
for _, p := range parts { for _, p := range parts {
if p == "" || p == "+" { if p == "" || p == "+" {
return "", errors.New("cannot build strict topic, wildcard exists") return "", errors.New("cannot build strict topic, wildcard exists")
+1 -1
View File
@@ -3,5 +3,5 @@ package model
type UpdateInfo struct { type UpdateInfo struct {
Version int `json:"version"` Version int `json:"version"`
DownloadURL string `json:"url"` DownloadURL string `json:"url"`
Notes bool `json:"notes"` Notes string `json:"notes"`
} }
-16
View File
@@ -1,16 +0,0 @@
package utils
// 变动
// 常量
const (
// 版本号
APP_VERSION = 0
Log_file_dic = "./logs"
MQTT_HOST = "ai.ronsunny.cn"
MQTT_PORT = 8093
PASSWORD = "123456"
PROJECT = "sentinel"
DEVICE_TPYE = "edge"
)
+108 -42
View File
@@ -1,83 +1,149 @@
package main package main
import ( import (
"crypto/sha256" "flag"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"sentinel/pkg/log" "runtime"
"sentinel/pkg/utils"
"sentinel/pkg/device" "sentinel/pkg/device"
"sentinel/pkg/log"
"sentinel/pkg/net" "sentinel/pkg/net"
"strconv"
"syscall"
"time"
) )
func main() { func main() {
deviceID := device.GetDeviceID() // 定义命令行参数
fmt.Printf("[updater] device id: %s\n", deviceID) version := flag.String("version", "", "current version of main program")
exeDir, _ := os.Executable() flag.Parse()
target := filepath.Join(filepath.Dir(exeDir), "main_program_binary_name") // TODO: 替换
if err := RunUpdate(deviceID, target); err != nil { if *version == "" {
log.Fatalf("[updater] update failed: %v", err) // updater 视角:-1 表示“未知版本”,一定触发更新检测
*version = "0"
log.Println("[updater] --version not provided, fallback to -1")
fmt.Println("[updater] 主程序版本号:", *version)
} }
fmt.Println("[updater] update finished")
}
// RunUpdate 检查更新、下载、替换主程序并启动新程序 deviceID := device.GetDeviceID()
func RunUpdate(deviceID string, targetExe string) error { fmt.Printf("[updater] 当前设备id: %s\n", deviceID)
versionInt, err := strconv.Atoi(*version)
if err != nil {
log.Println("[updater] invalid version:", *version)
versionInt = 0
}
if err := RunUpdate(deviceID, versionInt); err != nil {
log.Fatalf("[updater] 更新失败: %v", err)
}
fmt.Println("[updater] 更新程序结束")
}
func RunUpdate(deviceID string, version int) error {
// 1. 检查更新
info, err := api.CheckUpdate(deviceID) info, err := api.CheckUpdate(deviceID)
if err != nil {
fmt.Println("[updater] 请求错误,请检查网络")
return err
}
fmt.Println("[updater] 新版本:", info.Version)
fmt.Println("[updater] 新内容:", info.Notes)
fmt.Println("[updater] 下载地址:", info.DownloadURL)
// 获取主程序路径
selfPath, err := os.Executable()
if err != nil { if err != nil {
return err return err
} }
// 2. 比对本地版本 selfDir := filepath.Dir(selfPath)
if info.Version <= utils.APP_VERSION { targetExe := filepath.Join(selfDir, "main.exe") // Windows 固定名,可根据实际改
fmt.Println("[updater] already latest version:", utils.APP_VERSION)
return nil
}
fmt.Println("[updater] updating to version:", info.Version, "notes:", info.Notes)
// 3. 下载新版本到临时目录 // 2. 对比版本号,没有新版本则直接启动原程序
tmpFile := filepath.Join(os.TempDir(), "new_program_tmp") if info.Version <= version {
out, err := os.Create(tmpFile) fmt.Println("[updater] 暂未发现新版本,启动原程序")
cmd := exec.Command(targetExe)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if runtime.GOOS == "windows" {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
if err := cmd.Start(); err != nil {
return err
}
os.Exit(0)
}
// 3. 有新版本则先备份到 ./tmp/old_app/
backupDir := filepath.Join(selfDir, "tmp", "old_app")
_ = os.MkdirAll(backupDir, 0755)
backupFile := filepath.Join(backupDir, "main_"+strconv.Itoa(version)+".bak")
if err := os.Rename(targetExe, backupFile); err != nil {
fmt.Println("[updater] 备份失败,但继续更新:", err)
}
// 4. 下载新版本到 ./tmp
tmpDir := filepath.Join(selfDir, "tmp")
_ = os.MkdirAll(tmpDir, 0755)
u, err := url.Parse(info.DownloadURL)
if err != nil { if err != nil {
return fmt.Errorf("create temp file failed: %w", err) return err
} }
defer out.Close() base := filepath.Base(u.Path)
ext := filepath.Ext(base)
resp2, err := http.Get(info.DownloadURL) tmpFile, err := os.CreateTemp(tmpDir, "main_*"+ext)
if err != nil { if err != nil {
return fmt.Errorf("download failed: %w", err) return err
} }
defer resp2.Body.Close() defer tmpFile.Close()
h := sha256.New() client := &http.Client{Timeout: 30 * time.Second}
mw := io.MultiWriter(out, h) resp, err := client.Get(info.DownloadURL)
if _, err := io.Copy(mw, resp2.Body); err != nil { if err != nil {
return fmt.Errorf("write temp file failed: %w", err) return err
}
defer resp.Body.Close()
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
return err
} }
// 4. 替换 targetExe // 5. 重命名新文件到 ./main.exe
backup := targetExe + ".bak" tmpFile.Close() // 关闭临时文件才能重命名
_ = os.Remove(backup) maxRetry := 20
_ = os.Rename(targetExe, backup) // 备份旧版本 for i := 0; i < maxRetry; i++ {
if err := os.Rename(tmpFile, targetExe); err != nil { err := os.Rename(tmpFile.Name(), targetExe)
return fmt.Errorf("replace main program failed: %w", err) if err == nil {
break
}
fmt.Println("[updater] 文件被占用,等待 500ms 再尝试...")
time.Sleep(500 * time.Millisecond)
if i == maxRetry-1 {
return fmt.Errorf("替换失败: %w", err)
}
} }
fmt.Println("[updater] replaced main program")
// 5. 启动主程序 // 6. 启动主程序,同时完全退出自己
cmd := exec.Command(targetExe) cmd := exec.Command(targetExe)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
if runtime.GOOS == "windows" {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
return fmt.Errorf("start new program failed: %w", err) return err
} }
fmt.Println("[updater] new program started successfully") fmt.Printf("[updater] 更新完成,新程序已启动 (pid=%d),退出更新程序\n", cmd.Process.Pid)
os.Exit(0)
return nil return nil
} }