diff --git a/cmd/main.go b/cmd/main.go index b3e21a6..78c5436 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,8 +3,10 @@ package main import ( "bufio" "crypto/x509" + "errors" "flag" "fmt" + "log" "os" "strconv" "strings" @@ -48,7 +50,7 @@ func formatSize(size int64) string { return FloatToString(float64(size)) + "B" } -func searchCommand(args []string) { +func execSearch(args []string) { searchCmd := flag.NewFlagSet("search", flag.ExitOnError) sortByFilename := searchCmd.Bool("s", false, "sort results by filename") @@ -76,7 +78,7 @@ func searchCommand(args []string) { printer.Print() } -func transferLoop(transfer *xdcc.XdccTransfer) { +func transferLoop(transfer xdcc.Transfer) { bar := pb.NewProgressBar() evts := transfer.PollEvents() @@ -104,7 +106,7 @@ func suggestUnknownAuthoritySwitch(err error) { } } -func doTransfer(transfer *xdcc.XdccTransfer) { +func doTransfer(transfer xdcc.Transfer) { err := transfer.Start() if err != nil { fmt.Println(err) @@ -116,20 +118,21 @@ func doTransfer(transfer *xdcc.XdccTransfer) { } func parseFlags(flagSet *flag.FlagSet, args []string) []string { - findFirstFlag := func(args []string) int { - for i, arg := range args { - if strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") { - return i - } - } - return -1 - } flagIdx := findFirstFlag(args) - if flagIdx >= 0 { - flagSet.Parse(args[flagIdx:]) - return args[:flagIdx] + if flagIdx < 0 { + return args } - return args + flagSet.Parse(args[flagIdx:]) + return args[:flagIdx] +} + +func findFirstFlag(args []string) int { + for i, arg := range args { + if strings.HasPrefix(arg, "-") || strings.HasPrefix(arg, "--") { + return i + } + } + return -1 } func loadUrlListFile(filePath string) []string { @@ -156,17 +159,17 @@ func loadUrlListFile(filePath string) []string { } func printGetUsageAndExit(flagSet *flag.FlagSet) { - fmt.Printf("usage: get url1 url2 ... [-o path] [-i file] [--allow-unknown-authority]\n\nFlag set:\n") + fmt.Printf("usage: get url1 url2 ... [-o path] [-i file] [--ssl-only]\n\nFlag set:\n") flagSet.PrintDefaults() os.Exit(0) } -func getCommand(args []string) { +func execGet(args []string) { getCmd := flag.NewFlagSet("get", flag.ExitOnError) path := getCmd.String("o", ".", "output folder of dowloaded file") inputFile := getCmd.String("i", "", "input file containing a list of urls") - skipCertificateCheck := getCmd.Bool("allow-unknown-authority", false, "skip x509 certificate check during tls connection") - noSSL := getCmd.Bool("no-ssl", false, "disable SSL.") + + sslOnly := getCmd.Bool("ssl-only", false, "force the client to use TSL connection") urlList := parseFlags(getCmd, args) @@ -180,41 +183,45 @@ func getCommand(args []string) { wg := sync.WaitGroup{} for _, urlStr := range urlList { - if strings.HasPrefix(urlStr, "irc://") { - url, err := xdcc.ParseURL(urlStr) - - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - - wg.Add(1) - transfer := xdcc.NewTransfer(*url, *path, !*noSSL, *skipCertificateCheck) - go func(transfer *xdcc.XdccTransfer) { - doTransfer(transfer) - wg.Done() - }(transfer) - } else { - fmt.Printf("no valid irc url %s\n", urlStr) + url, err := xdcc.ParseURL(urlStr) + if errors.Is(err, xdcc.ErrInvalidURL) { + log.Printf("no valid irc url: %s\n", urlStr) + continue } + + if err != nil { + log.Println(err.Error()) + os.Exit(1) + } + + transfer := xdcc.NewTransfer(xdcc.Config{ + File: *url, + OutPath: *path, + SSLOnly: *sslOnly, + }) + + wg.Add(1) + go func(transfer xdcc.Transfer) { + doTransfer(transfer) + wg.Done() + }(transfer) } wg.Wait() } func main() { - if len(os.Args) < 2 { - fmt.Println("one of the following subcommands is expected: [search, get]") + log.Println("one of the following subcommands is expected: [search, get]") os.Exit(1) } switch os.Args[1] { case "search": - searchCommand(os.Args[2:]) + execSearch(os.Args[2:]) case "get": - getCommand(os.Args[2:]) + execGet(os.Args[2:]) default: - fmt.Println("no such command: ", os.Args[1]) + log.Println("no such command: ", os.Args[1]) os.Exit(1) } } diff --git a/xdcc/url.go b/xdcc/url.go index cabdea0..1552ac4 100644 --- a/xdcc/url.go +++ b/xdcc/url.go @@ -29,15 +29,17 @@ func parseSlot(slotStr string) (int, error) { return strconv.Atoi(slotStr) } +var ErrInvalidURL = errors.New("invalid IRC url") + // url has the following format: irc://network/channel/bot/slot func ParseURL(url string) (*IRCFile, error) { if !strings.HasPrefix(url, "irc://") { - return nil, errors.New("not an IRC url") + return nil, ErrInvalidURL } fields := strings.Split(strings.TrimPrefix(url, "irc://"), "/") if len(fields) != ircFileURLFields { - return nil, errors.New("invalid IRC url") + return nil, ErrInvalidURL } slot, err := parseSlot(fields[3]) diff --git a/xdcc/xdcc.go b/xdcc/xdcc.go index 7eb8344..3036786 100644 --- a/xdcc/xdcc.go +++ b/xdcc/xdcc.go @@ -127,6 +127,37 @@ type TransferAbortedEvent struct { const maxConnAttempts = 5 +type Transfer interface { + Start() error + PollEvents() chan TransferEvent +} + +type retryTransfer struct { + *XdccTransfer + conf Config +} + +func (t *retryTransfer) Start() error { + t1 := newXdccTransfer(t.conf, true, false) + if err := t1.conn.Connect(); err == nil { + t.XdccTransfer = t1 + return nil + } + + t2 := newXdccTransfer(t.conf, true, true) + if err := t1.conn.Connect(); err == nil { + t.XdccTransfer = t2 + return nil + } + + t.XdccTransfer = newXdccTransfer(t.conf, false, false) + return t.XdccTransfer.conn.Connect() +} + +func (t *retryTransfer) PollEvents() chan TransferEvent { + return t.XdccTransfer.PollEvents() +} + type XdccTransfer struct { filePath string url IRCFile @@ -136,14 +167,32 @@ type XdccTransfer struct { events chan TransferEvent } -func NewTransfer(url IRCFile, filePath string, enableSSL bool, skipCertificateCheck bool) *XdccTransfer { +type Config struct { + File IRCFile + OutPath string + SSLOnly bool +} + +func NewTransfer(c Config) Transfer { + if c.SSLOnly { + return newXdccTransfer(c, true, false) + } + + return &retryTransfer{ + conf: c, + } +} + +func newXdccTransfer(c Config, enableSSL bool, skipCertificateCheck bool) *XdccTransfer { rand.Seed(time.Now().UTC().UnixNano()) nick := IRCClientUserName + strconv.Itoa(int(rand.Uint32())) + file := c.File + config := irc.NewConfig(nick) config.SSL = enableSSL - config.SSLConfig = &tls.Config{ServerName: url.Network, InsecureSkipVerify: skipCertificateCheck} - config.Server = url.Network + config.SSLConfig = &tls.Config{ServerName: file.Network, InsecureSkipVerify: skipCertificateCheck} + config.Server = file.Network config.NewNick = func(nick string) string { return nick + "" + strconv.Itoa(int(rand.Uint32())) } @@ -152,13 +201,13 @@ func NewTransfer(url IRCFile, filePath string, enableSSL bool, skipCertificateCh t := &XdccTransfer{ conn: conn, - url: url, - filePath: filePath, + url: file, + filePath: c.OutPath, started: false, connAttempts: 0, events: make(chan TransferEvent, defaultEventChanSize), } - t.setupHandlers(url.Channel, url.UserName, url.Slot) + t.setupHandlers(file.Channel, file.UserName, file.Slot) return t } @@ -211,7 +260,7 @@ func (transfer *XdccTransfer) setupHandlers(channel string, userName string, slo } if (err != nil || transfer.connAttempts >= maxConnAttempts) && !transfer.started { - transfer.notifyEvent(&TransferAbortedEvent{Error: "disconnected from server"}) + transfer.notifyEvent(&TransferAbortedEvent{Error: err.Error()}) } transfer.connAttempts++