diff --git a/go/domain/join.go b/go/domain/join.go index 10166c43..53868b23 100644 --- a/go/domain/join.go +++ b/go/domain/join.go @@ -25,17 +25,6 @@ import ( "golang.org/x/sys/windows" ) -const ( - // Known error codes - // nErrSuccess is the value returned on a successful run of NetJoinDomain. - // https://docs.microsoft.com/en-us/windows/win32/api/lmjoin/nf-lmjoin-netjoindomain#return-value - nErrSuccess = 0 - - // Domain Join constants - // https://learn.microsoft.com/en-us/windows/win32/api/lmjoin/nf-lmjoin-netjoindomain#parameters - server = 0 // use the local machine name -) - // DomainJoinOptions specifies options for the domain join operation. // See https://learn.microsoft.com/en-us/windows/win32/api/lmjoin/nf-lmjoin-netjoindomain type DomainJoinOptions uint32 @@ -73,7 +62,8 @@ var ( netJoinDomain = prodNetJoinDomain.Call ) -// Domain joins a client to a domain. +// Domain joins the local machine to a domain. +// See https://learn.microsoft.com/en-us/windows/win32/api/lmjoin/nf-lmjoin-netjoindomain for more details. func Domain(domain, joinOU, joinAccount, joinPassword string, options DomainJoinOptions) error { dom, err := windows.UTF16PtrFromString(domain) @@ -84,23 +74,28 @@ func Domain(domain, joinOU, joinAccount, joinPassword string, options DomainJoin if err != nil { return err } - ou, err := windows.UTF16PtrFromString(joinOU) - if err != nil { - return err + var ou *uint16 + if joinOU != "" { + var err error + ou, err = windows.UTF16PtrFromString(joinOU) + if err != nil { + return err + } } pw, err := windows.UTF16PtrFromString(joinPassword) if err != nil { return err } fmt.Printf("Attempting domain join with domain: %s, OU: %s, account: %s\n", domain, joinOU, joinAccount) + // https://learn.microsoft.com/en-us/windows/win32/api/lmjoin/nf-lmjoin-netjoindomain#parameters if returnCode, _, _ := netJoinDomain( - server, // lpServer + 0, // lpServer, 0 / Null means use the local machine. uintptr(unsafe.Pointer(dom)), // lpDomain uintptr(unsafe.Pointer(ou)), // lpMachineAccountOU uintptr(unsafe.Pointer(acc)), // lpAccount uintptr(unsafe.Pointer(pw)), // lpPassword uintptr(options), // fJoinOptions - ); returnCode != nErrSuccess { + ); windows.Errno(returnCode) != windows.ERROR_SUCCESS { return fmt.Errorf("failed to join domain: %w", windows.Errno(returnCode)) } diff --git a/go/domain/join_test.go b/go/domain/join_test.go index b3abd529..22f8f2d2 100644 --- a/go/domain/join_test.go +++ b/go/domain/join_test.go @@ -42,7 +42,7 @@ func TestDomain(t *testing.T) { }{ { name: "success", - retCode: nErrSuccess, + retCode: 0, }, { name: "failure",