diff --git a/.gitignore b/.gitignore index 8b13789..e69de29 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +0,0 @@ - diff --git a/client/client.go b/client/client.go index daf1e07..c539305 100644 --- a/client/client.go +++ b/client/client.go @@ -1,17 +1,16 @@ package client import ( - "log" - "github.com/keyvchan/NetAssist/pkg/flags" "github.com/keyvchan/NetAssist/pkg/utils" "github.com/keyvchan/NetAssist/protocol" + "github.com/rs/zerolog/log" ) // Req is the entry point for the client func Req() { - types := flags.GetArg(2) - log.Println("Req:", types) + types := flags.Config.Protocol + log.Info().Msg("Req: " + types) switch types { case "tcp": protocol.TCPClient() @@ -26,6 +25,6 @@ func Req() { case "ip": utils.Unimplemented("ip") default: - log.Fatal("unknown protocol", types) + log.Error().Msg("unknown protocol " + types) } } diff --git a/go.mod b/go.mod index 648d079..675d7e6 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,16 @@ module github.com/keyvchan/NetAssist go 1.18 + +require ( + github.com/rs/zerolog v1.28.0 + github.com/spf13/cobra v1.6.1 +) + +require ( + github.com/inconshreveable/mousetrap v1.0.1 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c8cd03b --- /dev/null +++ b/go.sum @@ -0,0 +1,23 @@ +github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc= +github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY= +github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= +github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 h1:foEbQz/B0Oz6YIqu/69kfXPYeFQAuuMYFkjaqXzl5Wo= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 00428b8..a750562 100644 --- a/main.go +++ b/main.go @@ -1,22 +1,87 @@ package main import ( - "log" + "os" "github.com/keyvchan/NetAssist/client" "github.com/keyvchan/NetAssist/pkg/flags" "github.com/keyvchan/NetAssist/server" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" ) -func main() { - flags.SetArgs() - types := flags.GetArg(1) - switch types { - case "server": - server.Serve() - case "client": - client.Req() - default: - log.Fatal("Unknown type: ", types) +var ( + rootCmd = &cobra.Command{ + Use: "NetAssit", + Short: "NetAssit is a network debugging and testing tool", + } + + serverCmd = &cobra.Command{ + Use: "server", + Short: "Start a server", + Run: func(_ *cobra.Command, _ []string) { + // pass the config to it + flags.Config.Type = "server" + + // concat the protocol and host + server.Serve() + + }, + } + + clientCmd = &cobra.Command{ + Use: "client", + Short: "Start a client", + Run: func(_ *cobra.Command, _ []string) { + // set type + flags.Config.Type = "client" + client.Req() + }, + } +) + +func execute() { + if err := rootCmd.Execute(); err != nil { + os.Exit(1) } } + +func initConfig() { + log.Info().Msg("initConfig") + +} + +func init() { + + // setup log + cobra.OnInitialize(initConfig) + + serverCmd.PersistentFlags().StringVarP(&flags.Config.Protocol, "protocol", "p", "tcp", "protocol") + serverCmd.PersistentFlags().IntVarP(&flags.Config.Port, "port", "P", 8080, "port") + serverCmd.PersistentFlags().StringVarP(&flags.Config.Host, "host", "H", "127.0.0.1", "host") + // protocol port host type is all required + serverCmd.MarkFlagsRequiredTogether("protocol") + serverCmd.MarkFlagsRequiredTogether("host", "port") + serverCmd.PersistentFlags().BoolVarP(&flags.Config.Binary, "binary", "b", false, "binary") + + rootCmd.AddCommand(serverCmd) + + clientCmd.PersistentFlags().StringVarP(&flags.Config.Protocol, "protocol", "p", "tcp", "protocol") + clientCmd.PersistentFlags().IntVarP(&flags.Config.Port, "port", "P", 8080, "port") + clientCmd.PersistentFlags().StringVarP(&flags.Config.Host, "host", "H", "127.0.0.1", "host") + // protocol port host type is all required + clientCmd.MarkFlagsRequiredTogether("protocol") + clientCmd.MarkFlagsRequiredTogether("port", "host") + clientCmd.PersistentFlags().BoolVarP(&flags.Config.Binary, "binary", "b", false, "binary") + rootCmd.AddCommand(clientCmd) + +} + +func main() { + // setup log + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + // set command line parsing + rootCmd.Execute() + +} diff --git a/pkg/connection/file.go b/pkg/connection/file.go index ed6ef0f..b3976e4 100644 --- a/pkg/connection/file.go +++ b/pkg/connection/file.go @@ -4,13 +4,12 @@ import ( "bufio" "bytes" "encoding/hex" - "errors" "fmt" - "log" "os" "github.com/keyvchan/NetAssist/pkg/flags" "github.com/keyvchan/NetAssist/pkg/message" + "github.com/rs/zerolog/log" ) // File represents a abstraced file connection. @@ -42,19 +41,19 @@ func (f File) ReadMessage() message.Message { } func ReadStdin(stdin *os.File) message.Message { - input_binary := flags.GetArg(4) + input_binary := flags.Config.Binary scanner := bufio.NewScanner(stdin) if scanner.Scan() { buf := []byte{} - if input_binary == "--binary" { + if input_binary { byte_slices := bytes.Split(scanner.Bytes(), []byte(" ")) for _, byte_slice := range byte_slices { new_byte := make([]byte, 1024) n, err := hex.Decode(new_byte, byte_slice) if err != nil { // hex parse error, ignore this byte - log.Println(err) + log.Err(err).Msg("Could not parse hex") continue } buf = append(buf, new_byte[:n]...) @@ -70,7 +69,7 @@ func ReadStdin(stdin *os.File) message.Message { Addr: nil, } } else { - log.Fatal(errors.New("failed to read from stdin")) + log.Error().Msg("failed to read from stdin") } return message.Message{} } @@ -84,13 +83,13 @@ func (f File) WriteMessage(msg message.Message) { } } -func WriteStdout(writter interface{}, message message.Message) { +func WriteStdout(_ interface{}, message message.Message) { // write to stdout - input_binary := flags.GetArg(4) - fmt.Println(message.Addr) + input_binary := flags.Config.Binary + log.Debug().Msg(message.String()) - if input_binary == "--binary" { + if input_binary { for i := 0; i < len(message.Content); i++ { fmt.Printf("%02x ", message.Content[i]) } diff --git a/pkg/connection/stream.go b/pkg/connection/stream.go index a682a2d..3f3b113 100644 --- a/pkg/connection/stream.go +++ b/pkg/connection/stream.go @@ -3,10 +3,11 @@ package connection import ( "errors" "io" - "log" + "net" "github.com/keyvchan/NetAssist/pkg/message" + "github.com/rs/zerolog/log" ) type Stream struct { @@ -25,19 +26,19 @@ func ReadConn(conn interface{}) message.Message { // type checking connn, ok := conn.(net.Conn) if !ok { - log.Fatal("Wrong type") + log.Error().Msg("Could not convert to net.Conn") } buf := make([]byte, 1024) // input_binary := GetArg(4) n, err := connn.Read(buf) if errors.Is(err, io.EOF) { - log.Println("Connection closed") + log.Error().Msg("Connection closed") // remove from channel *ClosedConn <- connn return message.Message{} } if err != nil { - log.Println(err) + log.Err(err) return message.Message{} } message := message.Message{ @@ -55,6 +56,6 @@ func (s Stream) WriteMessage(msg message.Message) { func WriteConn(conn net.Conn, message message.Message) { _, err := conn.Write(message.Content) if err != nil { - log.Fatal(err) + log.Err(err).Msg("Could not write message") } } diff --git a/pkg/flags/args.go b/pkg/flags/args.go index 688c43a..d821ff4 100644 --- a/pkg/flags/args.go +++ b/pkg/flags/args.go @@ -1,27 +1,34 @@ package flags -import ( - "log" - "os" -) +type Configuration struct { + Address string + Type string + Protocol string + Port int + Host string + Binary bool // binary transfer or text transfer + BinaryFile string // binary file path +} + +var Config = Configuration{} // args stores the arguments -var args []string +// var args []string // SetArgs retrieves the arguments from the command line -func SetArgs() { - args = os.Args -} +// func SetArgs() { +// args = os.Args +// } // GetArg returns the nth argument -func GetArg(i int) string { - if i < len(args) { - return args[i] - } else { - if i == 4 { - return "text" - } - log.Fatal("Index out of range") - } - return "" -} +// func GetArg(i int) string { +// if i < len(args) { +// return args[i] +// } else { +// if i == 4 { +// return "text" +// } +// log.Fatal("Index out of range") +// } +// return "" +// } diff --git a/pkg/message/passing.go b/pkg/message/passing.go index b35ac5e..7d9a320 100644 --- a/pkg/message/passing.go +++ b/pkg/message/passing.go @@ -1,8 +1,7 @@ package message import ( - "errors" - "log" + "github.com/rs/zerolog/log" ) // Read reads a message from the given reader @@ -13,7 +12,7 @@ func Read(message_chan chan Message, reader Reader) { if buf.Content != nil { message_chan <- buf } else { - log.Println(errors.New("could not read message"), reader) + log.Error().Msg("could not read message") } } diff --git a/protocol/tcp.go b/protocol/tcp.go index 29cdda5..7882ce3 100644 --- a/protocol/tcp.go +++ b/protocol/tcp.go @@ -1,21 +1,22 @@ package protocol import ( - "log" "net" "os" + "strconv" "github.com/keyvchan/NetAssist/pkg/connection" "github.com/keyvchan/NetAssist/pkg/flags" "github.com/keyvchan/NetAssist/pkg/message" + "github.com/rs/zerolog/log" ) // TCPServer is a TCP server, read from stdin and write to the client and read from the client write it to stdout func TCPServer() { - address := flags.GetArg(3) + address := flags.Config.Host + ":" + strconv.Itoa(flags.Config.Port) listener, err := net.Listen("tcp", address) if err != nil { - log.Fatal(err) + log.Err(err).Msg("Could not listen on address") } // store all connections in a slice // NOTE: Possibly race condition @@ -47,9 +48,9 @@ func accept_conn(read_chan chan message.Message, listener net.Listener, connecti for { conn, err := listener.Accept() if err != nil { - log.Fatal(err) + log.Err(err).Msg("Could not accept connection") } - log.Println("Accepted connection") + log.Info().Msg("Accepted connection") connections[conn] = true // create conn tcp_client := connection.Stream{ @@ -63,10 +64,10 @@ func accept_conn(read_chan chan message.Message, listener net.Listener, connecti // TCPClient is a TCP client, read from stdin and write to the server and read from the server when it to stdout func TCPClient() { - address := flags.GetArg(3) + address := flags.Config.Host + ":" + strconv.Itoa(flags.Config.Port) conn, err := net.Dial("tcp", address) if err != nil { - log.Fatal(err) + log.Err(err).Msg("Could not connect to server") } defer conn.Close() diff --git a/protocol/udp.go b/protocol/udp.go index aedbf2f..b64f278 100644 --- a/protocol/udp.go +++ b/protocol/udp.go @@ -3,6 +3,7 @@ package protocol import ( "log" "net" + "strconv" "github.com/keyvchan/NetAssist/pkg/connection" "github.com/keyvchan/NetAssist/pkg/flags" @@ -11,7 +12,7 @@ import ( // UDPServer is a UDP server, it reads from stdin and writes to stdout and read from the client and write to the stdout func UDPServer() { - address := flags.GetArg(3) + address := flags.Config.Host + ":" + strconv.Itoa(flags.Config.Port) conn, err := net.ListenPacket("udp", address) if err != nil { log.Fatal(err) @@ -33,7 +34,7 @@ func UDPServer() { // UDPClient is a UDP client, it reads from stdin and writes to stdout and read from the server and write to the stdout func UDPClient() { - address := flags.GetArg(3) + address := flags.Config.Host + ":" + strconv.Itoa(flags.Config.Port) conn, err := net.Dial("udp", address) if err != nil { log.Fatal(err) diff --git a/protocol/unix.go b/protocol/unix.go index 25bedc7..900b690 100644 --- a/protocol/unix.go +++ b/protocol/unix.go @@ -2,32 +2,32 @@ package protocol import ( "bufio" - "errors" "fmt" - "log" "net" "os" + "strconv" "github.com/keyvchan/NetAssist/pkg/flags" + "github.com/rs/zerolog/log" ) // UnixServer is a server for the unix socket, it bridged server to stdout and stdin to server func UnixServer() { - address := flags.GetArg(3) + address := flags.Config.Host + ":" + strconv.Itoa(flags.Config.Port) listener, err := net.Listen("unix", address) if err != nil { - log.Fatal(err) + log.Err(err).Msg("failed to listen on unix socket") } conn, err := listener.Accept() - log.Println("Accepted connection") + log.Info().Msg("Accepted connection") if err != nil { - log.Fatal(err) + log.Err(err).Msg("failed to accept connection") } buf := make([]byte, 1024) for { _, err = conn.Read(buf) if err != nil { - log.Fatal(err) + log.Err(err).Msg("failed to read from connection") } fmt.Println(string(buf)) } @@ -36,10 +36,10 @@ func UnixServer() { // UnixClient is a client for the unix socket, it bridged stdin to server and server to stdout func UnixClient() { - address := flags.GetArg(3) + address := flags.Config.Host + ":" + strconv.Itoa(flags.Config.Port) conn, err := net.Dial("unix", address) if err != nil { - log.Fatal(err) + log.Err(err).Msg("failed to dial unix socket") } defer conn.Close() for { @@ -47,7 +47,7 @@ func UnixClient() { if scanner.Scan() { conn.Write([]byte(scanner.Text())) } else { - log.Fatal(errors.New("failed to read from stdin")) + log.Error().Msg("failed to read from stdin") } } diff --git a/protocol/unixgram.go b/protocol/unixgram.go index f485381..a7a4b2f 100644 --- a/protocol/unixgram.go +++ b/protocol/unixgram.go @@ -7,13 +7,14 @@ import ( "log" "net" "os" + "strconv" "github.com/keyvchan/NetAssist/pkg/flags" ) // UnixgramServer is a server for the unixgram protocol, its create a bridge between server and client func UnixgramServer() { - address := flags.GetArg(3) + address := flags.Config.Host + ":" + strconv.Itoa(flags.Config.Port) conn, err := net.ListenPacket("unixgram", address) if err != nil { log.Fatal(err) @@ -27,7 +28,7 @@ func UnixgramServer() { // UnixgramClient is a client for the unixgram protocol, its create a bridge between server and client func UnixgramClient() { - address := flags.GetArg(3) + address := flags.Config.Host + ":" + strconv.Itoa(flags.Config.Port) conn, err := net.Dial("unixgram", address) if err != nil { log.Fatal(err) diff --git a/server/server.go b/server/server.go index 95b1249..3beb1b4 100644 --- a/server/server.go +++ b/server/server.go @@ -1,16 +1,15 @@ package server import ( - "log" - "github.com/keyvchan/NetAssist/pkg/flags" "github.com/keyvchan/NetAssist/pkg/utils" "github.com/keyvchan/NetAssist/protocol" + "github.com/rs/zerolog/log" ) func Serve() { - types := flags.GetArg(2) - log.Println("Serve:", types) + types := flags.Config.Protocol + log.Info().Msg("Serve: " + types) switch types { case "tcp": protocol.TCPServer() @@ -25,6 +24,6 @@ func Serve() { case "ip": utils.Unimplemented("ip") default: - log.Fatal("unknow protocol: ", types) + log.Error().Msg("unknow protocol: " + types) } }