diff --git a/agent/agent_windows.go b/agent/agent_windows.go index c9f8935..75b148c 100644 --- a/agent/agent_windows.go +++ b/agent/agent_windows.go @@ -161,22 +161,37 @@ func (a *Agent) RunScript(code string, shell string, args []string, timeout int, defer cancel() var timedOut = false + var token *wintoken.Token + var envBlock *uint16 + usingEnvVars := len(envVars) > 0 cmd := exec.Command(exe, cmdArgs...) if runasuser { - token, err := wintoken.GetInteractiveToken(wintoken.TokenImpersonation) + token, err = wintoken.GetInteractiveToken(wintoken.TokenImpersonation) if err == nil { defer token.Close() cmd.SysProcAttr = &syscall.SysProcAttr{Token: syscall.Token(token.Token()), HideWindow: true} + + if usingEnvVars { + envBlock, err = CreateEnvironmentBlock(syscall.Token(token.Token())) + if err == nil { + defer DestroyEnvironmentBlock(envBlock) + userEnv := EnvironmentBlockToSlice(envBlock) + cmd.Env = userEnv + } else { + cmd.Env = os.Environ() + } + } } + } else if usingEnvVars { + cmd.Env = os.Environ() + } + + if usingEnvVars { + cmd.Env = append(cmd.Env, envVars...) } cmd.Stdout = &outb cmd.Stderr = &errb - if len(envVars) > 0 { - cmd.Env = os.Environ() - cmd.Env = append(cmd.Env, envVars...) - } - if cmdErr := cmd.Start(); cmdErr != nil { a.Logger.Debugln(cmdErr) return "", cmdErr.Error(), 65, cmdErr diff --git a/agent/syscall_windows.go b/agent/syscall_windows.go index 26173ef..e7eccfd 100644 --- a/agent/syscall_windows.go +++ b/agent/syscall_windows.go @@ -24,11 +24,14 @@ var _ unsafe.Pointer var ( modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + userenv = windows.NewLazyDLL("userenv.dll") procFormatMessageW = modkernel32.NewProc("FormatMessageW") procGetOldestEventLogRecord = modadvapi32.NewProc("GetOldestEventLogRecord") procLoadLibraryExW = modkernel32.NewProc("LoadLibraryExW") procReadEventLogW = modadvapi32.NewProc("ReadEventLogW") + procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock") + procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock") ) // https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-eventlogrecord @@ -114,3 +117,47 @@ func ReadEventLog(eventLog w32.HANDLE, readFlags ReadFlag, recordOffset uint32, } return } + +func CreateEnvironmentBlock(token syscall.Token) (*uint16, error) { + var envBlock *uint16 + + ret, _, err := procCreateEnvironmentBlock.Call( + uintptr(unsafe.Pointer(&envBlock)), + uintptr(token), + 0, + ) + if ret == 0 { + return nil, err + } + + return envBlock, nil +} + +func DestroyEnvironmentBlock(envBlock *uint16) error { + ret, _, err := procDestroyEnvironmentBlock.Call(uintptr(unsafe.Pointer(envBlock))) + if ret == 0 { + return err + } + return nil +} + +func EnvironmentBlockToSlice(envBlock *uint16) []string { + var envs []string + + for { + len := 0 + for *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(envBlock)) + uintptr(len*2))) != 0 { + len++ + } + + if len == 0 { + break + } + + env := syscall.UTF16ToString((*[1 << 29]uint16)(unsafe.Pointer(envBlock))[:len]) + envs = append(envs, env) + envBlock = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(envBlock)) + uintptr((len+1)*2))) + } + + return envs +}