fix crash when update fails to download fixes #33

This commit is contained in:
wh1te909 2023-04-29 16:56:15 -07:00
parent 588a4bcbf7
commit 4f01e214fd
3 changed files with 18 additions and 13 deletions

View File

@ -16,6 +16,7 @@ package agent
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -218,12 +219,12 @@ func (a *Agent) seEnforcing() bool {
return out.Status.Exit == 0 && strings.Contains(out.Stdout, "Enforcing") return out.Status.Exit == 0 && strings.Contains(out.Stdout, "Enforcing")
} }
func (a *Agent) AgentUpdate(url, inno, version string) { func (a *Agent) AgentUpdate(url, inno, version string) error {
self, err := os.Executable() self, err := os.Executable()
if err != nil { if err != nil {
a.Logger.Errorln("AgentUpdate() os.Executable():", err) a.Logger.Errorln("AgentUpdate() os.Executable():", err)
return return err
} }
// more reliable method to get current working directory than os.Getwd() // more reliable method to get current working directory than os.Getwd()
@ -233,7 +234,7 @@ func (a *Agent) AgentUpdate(url, inno, version string) {
f, err := os.CreateTemp(cwd, "trmm") f, err := os.CreateTemp(cwd, "trmm")
if err != nil { if err != nil {
a.Logger.Errorln("AgentUpdate() os.CreateTemp:", err) a.Logger.Errorln("AgentUpdate() os.CreateTemp:", err)
return return err
} }
defer os.Remove(f.Name()) defer os.Remove(f.Name())
@ -252,12 +253,12 @@ func (a *Agent) AgentUpdate(url, inno, version string) {
if err != nil { if err != nil {
a.Logger.Errorln("AgentUpdate() download:", err) a.Logger.Errorln("AgentUpdate() download:", err)
f.Close() f.Close()
return return err
} }
if r.IsError() { if r.IsError() {
a.Logger.Errorln("AgentUpdate() status code:", r.StatusCode()) a.Logger.Errorln("AgentUpdate() status code:", r.StatusCode())
f.Close() f.Close()
return return errors.New("err")
} }
f.Close() f.Close()
@ -265,7 +266,7 @@ func (a *Agent) AgentUpdate(url, inno, version string) {
err = os.Rename(f.Name(), self) err = os.Rename(f.Name(), self)
if err != nil { if err != nil {
a.Logger.Errorln("AgentUpdate() os.Rename():", err) a.Logger.Errorln("AgentUpdate() os.Rename():", err)
return return err
} }
if runtime.GOOS == "linux" && a.seEnforcing() { if runtime.GOOS == "linux" && a.seEnforcing() {
@ -283,10 +284,11 @@ func (a *Agent) AgentUpdate(url, inno, version string) {
case "darwin": case "darwin":
opts.Command = "launchctl kickstart -k system/tacticalagent" opts.Command = "launchctl kickstart -k system/tacticalagent"
default: default:
return return nil
} }
a.CmdV2(opts) a.CmdV2(opts)
return nil
} }
func (a *Agent) AgentUninstall(code string) { func (a *Agent) AgentUninstall(code string) {

View File

@ -601,7 +601,7 @@ func (a *Agent) UninstallCleanup() {
os.RemoveAll(a.WinRunAsUserTmpDir) os.RemoveAll(a.WinRunAsUserTmpDir)
} }
func (a *Agent) AgentUpdate(url, inno, version string) { func (a *Agent) AgentUpdate(url, inno, version string) error {
time.Sleep(time.Duration(randRange(1, 15)) * time.Second) time.Sleep(time.Duration(randRange(1, 15)) * time.Second)
a.KillHungUpdates() a.KillHungUpdates()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@ -620,13 +620,11 @@ func (a *Agent) AgentUpdate(url, inno, version string) {
r, err := rClient.R().SetOutput(updater).Get(url) r, err := rClient.R().SetOutput(updater).Get(url)
if err != nil { if err != nil {
a.Logger.Errorln(err) a.Logger.Errorln(err)
CMD("net", []string{"start", winSvcName}, 10, false) return err
return
} }
if r.IsError() { if r.IsError() {
a.Logger.Errorln("Download failed with status code", r.StatusCode()) a.Logger.Errorln("Download failed with status code", r.StatusCode())
CMD("net", []string{"start", winSvcName}, 10, false) return err
return
} }
innoLogFile := filepath.Join(a.WinTmpDir, fmt.Sprintf("tacticalagent_update_v%s.txt", version)) innoLogFile := filepath.Join(a.WinTmpDir, fmt.Sprintf("tacticalagent_update_v%s.txt", version))
@ -638,6 +636,7 @@ func (a *Agent) AgentUpdate(url, inno, version string) {
} }
cmd.Start() cmd.Start()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
return nil
} }
func (a *Agent) osString() string { func (a *Agent) osString() string {

View File

@ -481,7 +481,11 @@ func (a *Agent) RunRPC() {
} else { } else {
ret.Encode("ok") ret.Encode("ok")
msg.Respond(resp) msg.Respond(resp)
a.AgentUpdate(p.Data["url"], p.Data["inno"], p.Data["version"]) err := a.AgentUpdate(p.Data["url"], p.Data["inno"], p.Data["version"])
if err != nil {
atomic.StoreUint32(&agentUpdateLocker, 0)
return
}
atomic.StoreUint32(&agentUpdateLocker, 0) atomic.StoreUint32(&agentUpdateLocker, 0)
nc.Flush() nc.Flush()
nc.Close() nc.Close()