完善自我更新逻辑
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -4,50 +4,85 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"runtime"
|
||||||
|
"sentinel/pkg/config"
|
||||||
"sentinel/pkg/log"
|
"sentinel/pkg/log"
|
||||||
model2 "sentinel/pkg/model"
|
model2 "sentinel/pkg/model"
|
||||||
"sentinel/pkg/utils"
|
"sentinel/pkg/utils"
|
||||||
|
"strconv"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BusinessService struct {
|
type BusinessService struct {
|
||||||
mqtt *MQTTService
|
mqtt *MQTTService
|
||||||
deviceID string
|
deviceID string
|
||||||
project string
|
deptId string
|
||||||
cmdTopic string
|
cmdTopic string
|
||||||
deviceType string
|
deviceType string
|
||||||
|
subscriptions map[string]struct{} // 记录已订阅 topic
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBusinessService(m *MQTTService, project, deviceType, deviceID string) *BusinessService {
|
func NewBusinessService(m *MQTTService, deviceID string) *BusinessService {
|
||||||
// 根据统一规则生成 topic
|
|
||||||
cmdTopic := project + "/cmd/" + deviceType + "/" + deviceID + "/#"
|
|
||||||
return &BusinessService{
|
return &BusinessService{
|
||||||
mqtt: m,
|
mqtt: m,
|
||||||
project: project,
|
deviceID: deviceID,
|
||||||
deviceID: deviceID,
|
subscriptions: make(map[string]struct{}),
|
||||||
cmdTopic: cmdTopic,
|
|
||||||
deviceType: deviceType,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SubscribeTopic 订阅指定 topic,并记录可取消
|
||||||
|
func (b *BusinessService) SubscribeTopic(topic string, qos byte) error {
|
||||||
|
if err := b.mqtt.Subscribe(topic, qos); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.subscriptions[topic] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInitTopic(deviceID string) string {
|
||||||
|
return "+/+/+/" + deviceID + "/#"
|
||||||
|
}
|
||||||
|
func (b *BusinessService) getOwnTopic(deviceID string) string {
|
||||||
|
return b.deptId + "/cmd/" + b.deviceType + "/" + deviceID + "/#"
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BusinessService) Start() error {
|
func (b *BusinessService) Start() error {
|
||||||
// 订阅 cmd topic
|
if err := b.SubscribeTopic(getInitTopic(b.deviceID), 1); err != nil {
|
||||||
if err := b.mqtt.Subscribe(b.cmdTopic, 1); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
b.mqtt.SetMessageHandler(b.onMQTTMessage)
|
b.mqtt.SetMessageHandler(b.onMQTTMessage)
|
||||||
|
|
||||||
// 第一次连接就发送状态信息
|
// 第一次连接就发送状态信息
|
||||||
b.SendStatusInfo()
|
b.SendStatusInfo()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnsubscribeTopic 取消订阅指定 topic
|
||||||
|
func (b *BusinessService) UnsubscribeTopic(topic string) error {
|
||||||
|
token := b.mqtt.client.Unsubscribe(topic)
|
||||||
|
if token.Wait() && token.Error() != nil {
|
||||||
|
return token.Error()
|
||||||
|
}
|
||||||
|
delete(b.subscriptions, topic)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsubscribeAll 取消所有已订阅 topic
|
||||||
|
func (b *BusinessService) UnsubscribeAll() {
|
||||||
|
for topic := range b.subscriptions {
|
||||||
|
_ = b.mqtt.client.Unsubscribe(topic)
|
||||||
|
delete(b.subscriptions, topic)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 消息处理
|
// 消息处理
|
||||||
func (b *BusinessService) onMQTTMessage(topic string, payload []byte) {
|
func (b *BusinessService) onMQTTMessage(topic string, payload []byte) {
|
||||||
model := model2.FromStringToMqttTopic(topic)
|
model := model2.FromStringToMqttTopic(topic)
|
||||||
// 是指令
|
// 指令
|
||||||
if model.Domain == "cmd" {
|
if model.Domain == "cmd" && model.DeviceType == b.deviceType {
|
||||||
log.Println("收到指令:", model.Resource)
|
log.Println("收到指令:", model.Resource)
|
||||||
switch model.Resource {
|
switch model.Resource {
|
||||||
case "ping":
|
case "ping":
|
||||||
@@ -61,14 +96,26 @@ func (b *BusinessService) onMQTTMessage(topic string, payload []byte) {
|
|||||||
default:
|
default:
|
||||||
log.Println("未知的命令:", model.Resource)
|
log.Println("未知的命令:", model.Resource)
|
||||||
}
|
}
|
||||||
|
} else if model.Domain == "status" && model.Resource == "receipt" {
|
||||||
|
b.deviceType = model.DeviceType
|
||||||
|
b.deptId = model.DeptId
|
||||||
|
// 取消订阅之前的初始化主题
|
||||||
|
if b.UnsubscribeTopic(getInitTopic(b.deviceID)) != nil {
|
||||||
|
log.Error("无法取消初始化主题")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 新订阅属于自己的主题
|
||||||
|
if b.SubscribeTopic(b.getOwnTopic(b.deviceID), 1) != nil {
|
||||||
|
log.Error("无法定于属于自己的主题")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("设备初始化成功:所属项目:", model.DeptId, "\t设备类型:", model.DeviceType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BusinessService) SendStatusInfo() {
|
func (b *BusinessService) SendStatusInfo() {
|
||||||
info := map[string]interface{}{
|
info := map[string]interface{}{
|
||||||
"project": utils.PROJECT,
|
"version": config.APP_VERSION,
|
||||||
"deviceType": utils.DEVICE_TPYE,
|
|
||||||
"version": utils.APP_VERSION,
|
|
||||||
"online": true,
|
"online": true,
|
||||||
"ip": utils.GetLocalIP(),
|
"ip": utils.GetLocalIP(),
|
||||||
"hostname": utils.GetHostname(),
|
"hostname": utils.GetHostname(),
|
||||||
@@ -81,14 +128,15 @@ func (b *BusinessService) SendStatusInfo() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
payload, _ := json.Marshal(info)
|
payload, _ := json.Marshal(info)
|
||||||
topic := b.project + "/status/" + b.deviceType + "/" + b.deviceID + "/info"
|
topic := "x/status/x/" + b.deviceID + "/info"
|
||||||
qos := byte(1)
|
qos := byte(1)
|
||||||
retained := true
|
retained := true
|
||||||
|
log.Println("发送消息:", topic)
|
||||||
|
|
||||||
if err := b.mqtt.Publish(topic, qos, retained, payload); err != nil {
|
if err := b.mqtt.Publish(topic, qos, retained, payload); err != nil {
|
||||||
log.Println("[BUS] failed to send status info:", err)
|
log.Println("发送状态信息出错:", err)
|
||||||
} else {
|
} else {
|
||||||
log.Println("[BUS] status info sent:", string(payload))
|
log.Println("发送状态信息:", string(payload))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,25 +155,37 @@ func (b *BusinessService) handleRestart() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新程序
|
// handleCheckUpdate 触发更新流程(主程序侧)
|
||||||
func (b *BusinessService) handleCheckUpdate() {
|
func (b *BusinessService) handleCheckUpdate() {
|
||||||
exe, _ := os.Executable()
|
|
||||||
updaterPath := filepath.Join(filepath.Dir(exe), "updater")
|
args := []string{
|
||||||
if _, err := os.Stat(updaterPath); os.IsNotExist(err) {
|
"--version", strconv.Itoa(config.APP_VERSION),
|
||||||
if _, err2 := os.Stat(updaterPath + ".exe"); err2 == nil {
|
|
||||||
updaterPath = updaterPath + ".exe"
|
|
||||||
} else {
|
|
||||||
log.Println("[BUS] updater not found")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
cmd := exec.Command(updaterPath, "--target", exe)
|
|
||||||
|
cmd := exec.Command("./updater.exe", args...)
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
// OS 级脱离父进程
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
log.Println("[BUS] failed to start updater:", err)
|
log.Println("[BUS] failed to start updater:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Println("[BUS] exiting main program for update")
|
|
||||||
|
log.Println(
|
||||||
|
"[BUS] updater started (pid=%d), exiting main program\n",
|
||||||
|
cmd.Process.Pid,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 给 updater 留出启动窗口(尤其是 systemd / docker 环境)
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|||||||
+21
-13
@@ -2,41 +2,49 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sentinel/pkg/utils"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"sentinel/pkg/config"
|
||||||
"sentinel/pkg/device"
|
"sentinel/pkg/device"
|
||||||
"sentinel/pkg/log"
|
"sentinel/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
deviceID := device.GetDeviceID()
|
banner := `
|
||||||
log.Init(utils.Log_file_dic) // 初始化日志目录
|
==========================================================================
|
||||||
log.Info("Device id: " + deviceID) // 第一次启动记录
|
_______ _______ _ __________________ _ _______ _
|
||||||
|
( ____ \( ____ \( ( /|\__ __/\__ __/( ( /|( ____ \( \
|
||||||
|
| (_____ | (__ | \ | | | | | | | \ | || (__ | |
|
||||||
|
(_____ )| __) | (\ \) | | | | | | (\ \) || __) | |
|
||||||
|
/\____) || (____/\| ) \ | | | ___) (___| ) \ || (____/\| (____/\
|
||||||
|
\_______)(_______/|/ )_) )_( \_______/|/ )_)(_______/(_______/
|
||||||
|
==========================================================================
|
||||||
|
`
|
||||||
|
|
||||||
broker := fmt.Sprintf("tls://%s:%d", utils.MQTT_HOST, utils.MQTT_PORT)
|
fmt.Println(banner)
|
||||||
username := deviceID
|
deviceID := device.GetDeviceID()
|
||||||
password := utils.PASSWORD
|
log.Init(config.Log_file_dic) // 初始化日志目录
|
||||||
|
log.Info("Device id: " + deviceID) // 第一次启动记录
|
||||||
|
log.Println("版本号: ", config.APP_VERSION) // 第一次启动记录
|
||||||
|
|
||||||
var mqttSvc *MQTTService
|
var mqttSvc *MQTTService
|
||||||
firstFail := true // 标记是否第一次失败
|
firstFail := true // 标记是否第一次失败
|
||||||
for {
|
for {
|
||||||
mqttSvc = NewMQTTService(broker, username, username, password, 60)
|
mqttSvc = NewMQTTService(config.MQTT_BROKER, deviceID, deviceID, config.PASSWORD, 60)
|
||||||
err := mqttSvc.Connect()
|
err := mqttSvc.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if firstFail {
|
if firstFail {
|
||||||
log.Error("物联网服务连接失败,请先注册设备. DeviceID: " + deviceID + " ")
|
log.Error("物联网服务连接失败,如未注册设备,请先注册: " + deviceID)
|
||||||
firstFail = false
|
firstFail = false
|
||||||
}
|
}
|
||||||
time.Sleep(5 * time.Second) // 5秒后重试
|
time.Sleep(3 * time.Second) // 5秒后重试
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Info("物联网服务已启动")
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
defer mqttSvc.Close()
|
defer mqttSvc.Close()
|
||||||
|
|
||||||
biz := NewBusinessService(mqttSvc, utils.PROJECT, utils.DEVICE_TPYE, deviceID)
|
biz := NewBusinessService(mqttSvc, deviceID)
|
||||||
for {
|
for {
|
||||||
// MQTT业务
|
// MQTT业务
|
||||||
err := biz.Start()
|
err := biz.Start()
|
||||||
@@ -47,7 +55,7 @@ func main() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 个人业务
|
// 个人业务
|
||||||
test()
|
//test()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
// 变动
|
||||||
|
|
||||||
|
// 常量
|
||||||
|
const (
|
||||||
|
// 版本号
|
||||||
|
APP_VERSION = 1
|
||||||
|
|
||||||
|
Log_file_dic = "./logs"
|
||||||
|
MQTT_BROKER = "tls://ai.ronsunny.cn:8093"
|
||||||
|
PASSWORD = "123456"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DeviceType string
|
||||||
|
// DeptId string
|
||||||
|
)
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type MqttTopic struct {
|
type MqttTopic struct {
|
||||||
Project string
|
DeptId string
|
||||||
Domain string
|
Domain string
|
||||||
DeviceType string
|
DeviceType string
|
||||||
DeviceID string
|
DeviceID string
|
||||||
@@ -21,7 +21,7 @@ func FromStringToMqttTopic(topic string) *MqttTopic {
|
|||||||
parts = append(parts, "")
|
parts = append(parts, "")
|
||||||
}
|
}
|
||||||
return &MqttTopic{
|
return &MqttTopic{
|
||||||
Project: parts[0],
|
DeptId: parts[0],
|
||||||
Domain: parts[1],
|
Domain: parts[1],
|
||||||
DeviceType: parts[2],
|
DeviceType: parts[2],
|
||||||
DeviceID: parts[3],
|
DeviceID: parts[3],
|
||||||
@@ -38,7 +38,7 @@ func (m *MqttTopic) ToString() string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
return strings.Join([]string{
|
return strings.Join([]string{
|
||||||
toVal(m.Project),
|
toVal(m.DeptId),
|
||||||
toVal(m.Domain),
|
toVal(m.Domain),
|
||||||
toVal(m.DeviceType),
|
toVal(m.DeviceType),
|
||||||
toVal(m.DeviceID),
|
toVal(m.DeviceID),
|
||||||
@@ -48,7 +48,7 @@ func (m *MqttTopic) ToString() string {
|
|||||||
|
|
||||||
// 严格生成 topic,不允许 "+" 或空
|
// 严格生成 topic,不允许 "+" 或空
|
||||||
func (m *MqttTopic) Build() (string, error) {
|
func (m *MqttTopic) Build() (string, error) {
|
||||||
parts := []string{m.Project, m.Domain, m.DeviceType, m.DeviceID, m.Resource}
|
parts := []string{m.DeptId, m.Domain, m.DeviceType, m.DeviceID, m.Resource}
|
||||||
for _, p := range parts {
|
for _, p := range parts {
|
||||||
if p == "" || p == "+" {
|
if p == "" || p == "+" {
|
||||||
return "", errors.New("cannot build strict topic, wildcard exists")
|
return "", errors.New("cannot build strict topic, wildcard exists")
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package model
|
|||||||
type UpdateInfo struct {
|
type UpdateInfo struct {
|
||||||
Version int `json:"version"`
|
Version int `json:"version"`
|
||||||
DownloadURL string `json:"url"`
|
DownloadURL string `json:"url"`
|
||||||
Notes bool `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
// 变动
|
|
||||||
|
|
||||||
// 常量
|
|
||||||
const (
|
|
||||||
// 版本号
|
|
||||||
APP_VERSION = 0
|
|
||||||
|
|
||||||
Log_file_dic = "./logs"
|
|
||||||
MQTT_HOST = "ai.ronsunny.cn"
|
|
||||||
MQTT_PORT = 8093
|
|
||||||
PASSWORD = "123456"
|
|
||||||
PROJECT = "sentinel"
|
|
||||||
DEVICE_TPYE = "edge"
|
|
||||||
)
|
|
||||||
+108
-42
@@ -1,83 +1,149 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sentinel/pkg/log"
|
"runtime"
|
||||||
"sentinel/pkg/utils"
|
|
||||||
|
|
||||||
"sentinel/pkg/device"
|
"sentinel/pkg/device"
|
||||||
|
"sentinel/pkg/log"
|
||||||
"sentinel/pkg/net"
|
"sentinel/pkg/net"
|
||||||
|
"strconv"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
deviceID := device.GetDeviceID()
|
// 定义命令行参数
|
||||||
fmt.Printf("[updater] device id: %s\n", deviceID)
|
version := flag.String("version", "", "current version of main program")
|
||||||
|
|
||||||
exeDir, _ := os.Executable()
|
flag.Parse()
|
||||||
target := filepath.Join(filepath.Dir(exeDir), "main_program_binary_name") // TODO: 替换
|
|
||||||
|
|
||||||
if err := RunUpdate(deviceID, target); err != nil {
|
if *version == "" {
|
||||||
log.Fatalf("[updater] update failed: %v", err)
|
// updater 视角:-1 表示“未知版本”,一定触发更新检测
|
||||||
|
*version = "0"
|
||||||
|
log.Println("[updater] --version not provided, fallback to -1")
|
||||||
|
fmt.Println("[updater] 主程序版本号:", *version)
|
||||||
}
|
}
|
||||||
fmt.Println("[updater] update finished")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunUpdate 检查更新、下载、替换主程序并启动新程序
|
deviceID := device.GetDeviceID()
|
||||||
func RunUpdate(deviceID string, targetExe string) error {
|
fmt.Printf("[updater] 当前设备id: %s\n", deviceID)
|
||||||
|
versionInt, err := strconv.Atoi(*version)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("[updater] invalid version:", *version)
|
||||||
|
versionInt = 0
|
||||||
|
}
|
||||||
|
if err := RunUpdate(deviceID, versionInt); err != nil {
|
||||||
|
log.Fatalf("[updater] 更新失败: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Println("[updater] 更新程序结束")
|
||||||
|
}
|
||||||
|
func RunUpdate(deviceID string, version int) error {
|
||||||
|
// 1. 检查更新
|
||||||
info, err := api.CheckUpdate(deviceID)
|
info, err := api.CheckUpdate(deviceID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("[updater] 请求错误,请检查网络")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("[updater] 新版本:", info.Version)
|
||||||
|
fmt.Println("[updater] 新内容:", info.Notes)
|
||||||
|
fmt.Println("[updater] 下载地址:", info.DownloadURL)
|
||||||
|
|
||||||
|
// 获取主程序路径
|
||||||
|
selfPath, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// 2. 比对本地版本
|
selfDir := filepath.Dir(selfPath)
|
||||||
if info.Version <= utils.APP_VERSION {
|
targetExe := filepath.Join(selfDir, "main.exe") // Windows 固定名,可根据实际改
|
||||||
fmt.Println("[updater] already latest version:", utils.APP_VERSION)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
fmt.Println("[updater] updating to version:", info.Version, "notes:", info.Notes)
|
|
||||||
|
|
||||||
// 3. 下载新版本到临时目录
|
// 2. 对比版本号,没有新版本则直接启动原程序
|
||||||
tmpFile := filepath.Join(os.TempDir(), "new_program_tmp")
|
if info.Version <= version {
|
||||||
out, err := os.Create(tmpFile)
|
fmt.Println("[updater] 暂未发现新版本,启动原程序")
|
||||||
|
cmd := exec.Command(targetExe)
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 有新版本则先备份到 ./tmp/old_app/
|
||||||
|
backupDir := filepath.Join(selfDir, "tmp", "old_app")
|
||||||
|
_ = os.MkdirAll(backupDir, 0755)
|
||||||
|
backupFile := filepath.Join(backupDir, "main_"+strconv.Itoa(version)+".bak")
|
||||||
|
if err := os.Rename(targetExe, backupFile); err != nil {
|
||||||
|
fmt.Println("[updater] 备份失败,但继续更新:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 下载新版本到 ./tmp
|
||||||
|
tmpDir := filepath.Join(selfDir, "tmp")
|
||||||
|
_ = os.MkdirAll(tmpDir, 0755)
|
||||||
|
|
||||||
|
u, err := url.Parse(info.DownloadURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create temp file failed: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
defer out.Close()
|
base := filepath.Base(u.Path)
|
||||||
|
ext := filepath.Ext(base)
|
||||||
|
|
||||||
resp2, err := http.Get(info.DownloadURL)
|
tmpFile, err := os.CreateTemp(tmpDir, "main_*"+ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("download failed: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
defer resp2.Body.Close()
|
defer tmpFile.Close()
|
||||||
|
|
||||||
h := sha256.New()
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
mw := io.MultiWriter(out, h)
|
resp, err := client.Get(info.DownloadURL)
|
||||||
if _, err := io.Copy(mw, resp2.Body); err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("write temp file failed: %w", err)
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 替换 targetExe
|
// 5. 重命名新文件到 ./main.exe
|
||||||
backup := targetExe + ".bak"
|
tmpFile.Close() // 关闭临时文件才能重命名
|
||||||
_ = os.Remove(backup)
|
maxRetry := 20
|
||||||
_ = os.Rename(targetExe, backup) // 备份旧版本
|
for i := 0; i < maxRetry; i++ {
|
||||||
if err := os.Rename(tmpFile, targetExe); err != nil {
|
err := os.Rename(tmpFile.Name(), targetExe)
|
||||||
return fmt.Errorf("replace main program failed: %w", err)
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fmt.Println("[updater] 文件被占用,等待 500ms 再尝试...")
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
if i == maxRetry-1 {
|
||||||
|
return fmt.Errorf("替换失败: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
fmt.Println("[updater] replaced main program")
|
|
||||||
|
|
||||||
// 5. 启动新主程序
|
// 6. 启动主程序,同时完全退出自己
|
||||||
cmd := exec.Command(targetExe)
|
cmd := exec.Command(targetExe)
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
return fmt.Errorf("start new program failed: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("[updater] new program started successfully")
|
fmt.Printf("[updater] 更新完成,新程序已启动 (pid=%d),退出更新程序\n", cmd.Process.Pid)
|
||||||
|
os.Exit(0)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user