完善自我更新逻辑

This commit is contained in:
BBIT-Kai
2025-12-30 17:54:33 +08:00
parent ec29d883bd
commit 62e8ecb7d6
12 changed files with 248 additions and 112 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+90 -30
View File
@@ -4,50 +4,85 @@ import (
"encoding/json"
"os"
"os/exec"
"path/filepath"
"runtime"
"sentinel/pkg/config"
"sentinel/pkg/log"
model2 "sentinel/pkg/model"
"sentinel/pkg/utils"
"strconv"
"syscall"
"time"
)
type BusinessService struct {
mqtt *MQTTService
deviceID string
project string
deptId string
cmdTopic string
deviceType string
subscriptions map[string]struct{} // 记录已订阅 topic
}
func NewBusinessService(m *MQTTService, project, deviceType, deviceID string) *BusinessService {
// 根据统一规则生成 topic
cmdTopic := project + "/cmd/" + deviceType + "/" + deviceID + "/#"
func NewBusinessService(m *MQTTService, deviceID string) *BusinessService {
return &BusinessService{
mqtt: m,
project: project,
deviceID: deviceID,
cmdTopic: cmdTopic,
deviceType: deviceType,
subscriptions: make(map[string]struct{}),
}
}
// SubscribeTopic 订阅指定 topic,并记录可取消
func (b *BusinessService) SubscribeTopic(topic string, qos byte) error {
if err := b.mqtt.Subscribe(topic, qos); err != nil {
return err
}
b.subscriptions[topic] = struct{}{}
return nil
}
func getInitTopic(deviceID string) string {
return "+/+/+/" + deviceID + "/#"
}
func (b *BusinessService) getOwnTopic(deviceID string) string {
return b.deptId + "/cmd/" + b.deviceType + "/" + deviceID + "/#"
}
func (b *BusinessService) Start() error {
// 订阅 cmd topic
if err := b.mqtt.Subscribe(b.cmdTopic, 1); err != nil {
if err := b.SubscribeTopic(getInitTopic(b.deviceID), 1); err != nil {
return err
}
b.mqtt.SetMessageHandler(b.onMQTTMessage)
// 第一次连接就发送状态信息
b.SendStatusInfo()
return nil
}
// UnsubscribeTopic 取消订阅指定 topic
func (b *BusinessService) UnsubscribeTopic(topic string) error {
token := b.mqtt.client.Unsubscribe(topic)
if token.Wait() && token.Error() != nil {
return token.Error()
}
delete(b.subscriptions, topic)
return nil
}
// UnsubscribeAll 取消所有已订阅 topic
func (b *BusinessService) UnsubscribeAll() {
for topic := range b.subscriptions {
_ = b.mqtt.client.Unsubscribe(topic)
delete(b.subscriptions, topic)
}
}
// 消息处理
func (b *BusinessService) onMQTTMessage(topic string, payload []byte) {
model := model2.FromStringToMqttTopic(topic)
// 指令
if model.Domain == "cmd" {
// 指令
if model.Domain == "cmd" && model.DeviceType == b.deviceType {
log.Println("收到指令:", model.Resource)
switch model.Resource {
case "ping":
@@ -61,14 +96,26 @@ func (b *BusinessService) onMQTTMessage(topic string, payload []byte) {
default:
log.Println("未知的命令:", model.Resource)
}
} else if model.Domain == "status" && model.Resource == "receipt" {
b.deviceType = model.DeviceType
b.deptId = model.DeptId
// 取消订阅之前的初始化主题
if b.UnsubscribeTopic(getInitTopic(b.deviceID)) != nil {
log.Error("无法取消初始化主题")
return
}
// 新订阅属于自己的主题
if b.SubscribeTopic(b.getOwnTopic(b.deviceID), 1) != nil {
log.Error("无法定于属于自己的主题")
return
}
log.Println("设备初始化成功:所属项目:", model.DeptId, "\t设备类型:", model.DeviceType)
}
}
func (b *BusinessService) SendStatusInfo() {
info := map[string]interface{}{
"project": utils.PROJECT,
"deviceType": utils.DEVICE_TPYE,
"version": utils.APP_VERSION,
"version": config.APP_VERSION,
"online": true,
"ip": utils.GetLocalIP(),
"hostname": utils.GetHostname(),
@@ -81,14 +128,15 @@ func (b *BusinessService) SendStatusInfo() {
}
payload, _ := json.Marshal(info)
topic := b.project + "/status/" + b.deviceType + "/" + b.deviceID + "/info"
topic := "x/status/x/" + b.deviceID + "/info"
qos := byte(1)
retained := true
log.Println("发送消息:", topic)
if err := b.mqtt.Publish(topic, qos, retained, payload); err != nil {
log.Println("[BUS] failed to send status info:", err)
log.Println("发送状态信息出错:", err)
} else {
log.Println("[BUS] status info sent:", string(payload))
log.Println("发送状态信息:", string(payload))
}
}
@@ -107,25 +155,37 @@ func (b *BusinessService) handleRestart() {
os.Exit(0)
}
// 更新程序
// handleCheckUpdate 触发更新流程(主程序侧)
func (b *BusinessService) handleCheckUpdate() {
exe, _ := os.Executable()
updaterPath := filepath.Join(filepath.Dir(exe), "updater")
if _, err := os.Stat(updaterPath); os.IsNotExist(err) {
if _, err2 := os.Stat(updaterPath + ".exe"); err2 == nil {
updaterPath = updaterPath + ".exe"
} else {
log.Println("[BUS] updater not found")
return
args := []string{
"--version", strconv.Itoa(config.APP_VERSION),
}
}
cmd := exec.Command(updaterPath, "--target", exe)
cmd := exec.Command("./updater.exe", args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// OS 级脱离父进程
switch runtime.GOOS {
case "windows":
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
if err := cmd.Start(); err != nil {
log.Println("[BUS] failed to start updater:", err)
return
}
log.Println("[BUS] exiting main program for update")
log.Println(
"[BUS] updater started (pid=%d), exiting main program\n",
cmd.Process.Pid,
)
// 给 updater 留出启动窗口(尤其是 systemd / docker 环境)
time.Sleep(500 * time.Millisecond)
os.Exit(0)
}
+21 -13
View File
@@ -2,41 +2,49 @@ package main
import (
"fmt"
"sentinel/pkg/utils"
"time"
"sentinel/pkg/config"
"sentinel/pkg/device"
"sentinel/pkg/log"
)
func main() {
deviceID := device.GetDeviceID()
log.Init(utils.Log_file_dic) // 初始化日志目录
log.Info("Device id: " + deviceID) // 第一次启动记录
banner := `
==========================================================================
_______ _______ _ __________________ _ _______ _
( ____ \( ____ \( ( /|\__ __/\__ __/( ( /|( ____ \( \
| (_____ | (__ | \ | | | | | | | \ | || (__ | |
(_____ )| __) | (\ \) | | | | | | (\ \) || __) | |
/\____) || (____/\| ) \ | | | ___) (___| ) \ || (____/\| (____/\
\_______)(_______/|/ )_) )_( \_______/|/ )_)(_______/(_______/
==========================================================================
`
broker := fmt.Sprintf("tls://%s:%d", utils.MQTT_HOST, utils.MQTT_PORT)
username := deviceID
password := utils.PASSWORD
fmt.Println(banner)
deviceID := device.GetDeviceID()
log.Init(config.Log_file_dic) // 初始化日志目录
log.Info("Device id: " + deviceID) // 第一次启动记录
log.Println("版本号: ", config.APP_VERSION) // 第一次启动记录
var mqttSvc *MQTTService
firstFail := true // 标记是否第一次失败
for {
mqttSvc = NewMQTTService(broker, username, username, password, 60)
mqttSvc = NewMQTTService(config.MQTT_BROKER, deviceID, deviceID, config.PASSWORD, 60)
err := mqttSvc.Connect()
if err != nil {
if firstFail {
log.Error("物联网服务连接失败,请先注册设备. DeviceID: " + deviceID + " ")
log.Error("物联网服务连接失败,如未注册设备,请先注册: " + deviceID)
firstFail = false
}
time.Sleep(5 * time.Second) // 5秒后重试
time.Sleep(3 * time.Second) // 5秒后重试
continue
}
log.Info("物联网服务已启动")
break
}
defer mqttSvc.Close()
biz := NewBusinessService(mqttSvc, utils.PROJECT, utils.DEVICE_TPYE, deviceID)
biz := NewBusinessService(mqttSvc, deviceID)
for {
// MQTT业务
err := biz.Start()
@@ -47,7 +55,7 @@ func main() {
continue
}
// 个人业务
test()
//test()
break
}
+18
View File
@@ -0,0 +1,18 @@
package config
// 变动
// 常量
const (
// 版本号
APP_VERSION = 1
Log_file_dic = "./logs"
MQTT_BROKER = "tls://ai.ronsunny.cn:8093"
PASSWORD = "123456"
)
var (
// DeviceType string
// DeptId string
)
+4 -4
View File
@@ -6,7 +6,7 @@ import (
)
type MqttTopic struct {
Project string
DeptId string
Domain string
DeviceType string
DeviceID string
@@ -21,7 +21,7 @@ func FromStringToMqttTopic(topic string) *MqttTopic {
parts = append(parts, "")
}
return &MqttTopic{
Project: parts[0],
DeptId: parts[0],
Domain: parts[1],
DeviceType: parts[2],
DeviceID: parts[3],
@@ -38,7 +38,7 @@ func (m *MqttTopic) ToString() string {
return s
}
return strings.Join([]string{
toVal(m.Project),
toVal(m.DeptId),
toVal(m.Domain),
toVal(m.DeviceType),
toVal(m.DeviceID),
@@ -48,7 +48,7 @@ func (m *MqttTopic) ToString() string {
// 严格生成 topic,不允许 "+" 或空
func (m *MqttTopic) Build() (string, error) {
parts := []string{m.Project, m.Domain, m.DeviceType, m.DeviceID, m.Resource}
parts := []string{m.DeptId, m.Domain, m.DeviceType, m.DeviceID, m.Resource}
for _, p := range parts {
if p == "" || p == "+" {
return "", errors.New("cannot build strict topic, wildcard exists")
+1 -1
View File
@@ -3,5 +3,5 @@ package model
type UpdateInfo struct {
Version int `json:"version"`
DownloadURL string `json:"url"`
Notes bool `json:"notes"`
Notes string `json:"notes"`
}
-16
View File
@@ -1,16 +0,0 @@
package utils
// 变动
// 常量
const (
// 版本号
APP_VERSION = 0
Log_file_dic = "./logs"
MQTT_HOST = "ai.ronsunny.cn"
MQTT_PORT = 8093
PASSWORD = "123456"
PROJECT = "sentinel"
DEVICE_TPYE = "edge"
)
+118 -52
View File
@@ -1,83 +1,149 @@
package main
import (
"crypto/sha256"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"sentinel/pkg/log"
"sentinel/pkg/utils"
"runtime"
"sentinel/pkg/device"
"sentinel/pkg/log"
"sentinel/pkg/net"
"strconv"
"syscall"
"time"
)
func main() {
deviceID := device.GetDeviceID()
fmt.Printf("[updater] device id: %s\n", deviceID)
// 定义命令行参数
version := flag.String("version", "", "current version of main program")
exeDir, _ := os.Executable()
target := filepath.Join(filepath.Dir(exeDir), "main_program_binary_name") // TODO: 替换
flag.Parse()
if err := RunUpdate(deviceID, target); err != nil {
log.Fatalf("[updater] update failed: %v", err)
if *version == "" {
// updater 视角:-1 表示“未知版本”,一定触发更新检测
*version = "0"
log.Println("[updater] --version not provided, fallback to -1")
fmt.Println("[updater] 主程序版本号:", *version)
}
fmt.Println("[updater] update finished")
}
// RunUpdate 检查更新、下载、替换主程序并启动新程序
func RunUpdate(deviceID string, targetExe string) error {
deviceID := device.GetDeviceID()
fmt.Printf("[updater] 当前设备id: %s\n", deviceID)
versionInt, err := strconv.Atoi(*version)
if err != nil {
log.Println("[updater] invalid version:", *version)
versionInt = 0
}
if err := RunUpdate(deviceID, versionInt); err != nil {
log.Fatalf("[updater] 更新失败: %v", err)
}
fmt.Println("[updater] 更新程序结束")
}
func RunUpdate(deviceID string, version int) error {
// 1. 检查更新
info, err := api.CheckUpdate(deviceID)
if err != nil {
fmt.Println("[updater] 请求错误,请检查网络")
return err
}
fmt.Println("[updater] 新版本:", info.Version)
fmt.Println("[updater] 新内容:", info.Notes)
fmt.Println("[updater] 下载地址:", info.DownloadURL)
// 获取主程序路径
selfPath, err := os.Executable()
if err != nil {
return err
}
// 2. 比对本地版本
if info.Version <= utils.APP_VERSION {
fmt.Println("[updater] already latest version:", utils.APP_VERSION)
return nil
}
fmt.Println("[updater] updating to version:", info.Version, "notes:", info.Notes)
selfDir := filepath.Dir(selfPath)
targetExe := filepath.Join(selfDir, "main.exe") // Windows 固定名,可根据实际改
// 3. 下载新版本到临时目录
tmpFile := filepath.Join(os.TempDir(), "new_program_tmp")
out, err := os.Create(tmpFile)
if err != nil {
return fmt.Errorf("create temp file failed: %w", err)
}
defer out.Close()
resp2, err := http.Get(info.DownloadURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
defer resp2.Body.Close()
h := sha256.New()
mw := io.MultiWriter(out, h)
if _, err := io.Copy(mw, resp2.Body); err != nil {
return fmt.Errorf("write temp file failed: %w", err)
}
// 4. 替换 targetExe
backup := targetExe + ".bak"
_ = os.Remove(backup)
_ = os.Rename(targetExe, backup) // 备份旧版本
if err := os.Rename(tmpFile, targetExe); err != nil {
return fmt.Errorf("replace main program failed: %w", err)
}
fmt.Println("[updater] replaced main program")
// 5. 启动新主程序
// 2. 对比版本号,没有新版本则直接启动原程序
if info.Version <= version {
fmt.Println("[updater] 暂未发现新版本,启动原程序")
cmd := exec.Command(targetExe)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if runtime.GOOS == "windows" {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start new program failed: %w", err)
return err
}
os.Exit(0)
}
fmt.Println("[updater] new program started successfully")
// 3. 有新版本则先备份到 ./tmp/old_app/
backupDir := filepath.Join(selfDir, "tmp", "old_app")
_ = os.MkdirAll(backupDir, 0755)
backupFile := filepath.Join(backupDir, "main_"+strconv.Itoa(version)+".bak")
if err := os.Rename(targetExe, backupFile); err != nil {
fmt.Println("[updater] 备份失败,但继续更新:", err)
}
// 4. 下载新版本到 ./tmp
tmpDir := filepath.Join(selfDir, "tmp")
_ = os.MkdirAll(tmpDir, 0755)
u, err := url.Parse(info.DownloadURL)
if err != nil {
return err
}
base := filepath.Base(u.Path)
ext := filepath.Ext(base)
tmpFile, err := os.CreateTemp(tmpDir, "main_*"+ext)
if err != nil {
return err
}
defer tmpFile.Close()
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Get(info.DownloadURL)
if err != nil {
return err
}
defer resp.Body.Close()
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
return err
}
// 5. 重命名新文件到 ./main.exe
tmpFile.Close() // 关闭临时文件才能重命名
maxRetry := 20
for i := 0; i < maxRetry; i++ {
err := os.Rename(tmpFile.Name(), targetExe)
if err == nil {
break
}
fmt.Println("[updater] 文件被占用,等待 500ms 再尝试...")
time.Sleep(500 * time.Millisecond)
if i == maxRetry-1 {
return fmt.Errorf("替换失败: %w", err)
}
}
// 6. 启动主程序,同时完全退出自己
cmd := exec.Command(targetExe)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if runtime.GOOS == "windows" {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
if err := cmd.Start(); err != nil {
return err
}
fmt.Printf("[updater] 更新完成,新程序已启动 (pid=%d),退出更新程序\n", cmd.Process.Pid)
os.Exit(0)
return nil
}