diff --git a/README.md b/README.md index f1e7d01..9bddd9e 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,7 @@ # MASSRDNS ## Reverse DNS Lookup Tool - -This tool provides an efficient way to perform reverse DNS lookups on IP addresses, especially useful for large IP ranges. It uses concurrent workers and distributes the work among them to achieve faster results. Each request will randomly rotate betweeen the supplied DNS servers to split the load of a large CDIR across many DNS servers. +This Reverse DNS Lookup Tool is a sophisticated utility designed to perform reverse DNS lookups on large IP ranges efficiently. Built with concurrency in mind, it leverages multiple goroutines to expedite the process, making it highly scalable and performant. The tool utilizes a list of DNS servers, effectively load balancing the DNS queries across them. This not only distributes the request load but also provides redundancy; if one server fails or is slow, the tool can switch to another. Recognizing the real-world imperfections of network systems, the tool is intelligent enough to handle DNS server failures. After a certain threshold of consecutive failures, it automatically removes the faulty server from the list, ensuring that runtime is not bogged down by consistent non-performers. Furthermore, in the case of lookup failures due to network issues, the tool retries the lookup using different servers. This ensures that transient errors don't lead to missed lookups, enhancing the reliability of the results. ### Building the Project diff --git a/massrdns.go b/massrdns.go index 0938e15..007f232 100644 --- a/massrdns.go +++ b/massrdns.go @@ -15,11 +15,14 @@ import ( ) var dnsServers []string +var failureCounts = make(map[string]int) + +func loadDNSServersFromFile(filePath string) ([]string, error) { + var servers []string -func loadDNSServersFromFile(filePath string) error { file, err := os.Open(filePath) if err != nil { - return err + return nil, err } defer file.Close() @@ -27,32 +30,32 @@ func loadDNSServersFromFile(filePath string) error { for scanner.Scan() { server := scanner.Text() - // Check if the server contains a port if strings.Contains(server, ":") { host, port, err := net.SplitHostPort(server) if err != nil { - return fmt.Errorf("invalid IP:port format for %s", server) + return nil, fmt.Errorf("invalid IP:port format for %s", server) } if net.ParseIP(host) == nil { - return fmt.Errorf("invalid IP address in %s", server) + return nil, fmt.Errorf("invalid IP address in %s", server) } if _, err := strconv.Atoi(port); err != nil { - return fmt.Errorf("invalid port in %s", server) + return nil, fmt.Errorf("invalid port in %s", server) } } else { if net.ParseIP(server) == nil { - return fmt.Errorf("invalid IP address %s", server) + return nil, fmt.Errorf("invalid IP address %s", server) } - server += ":53" // Default to port 53 if not specified + server += ":53" } - dnsServers = append(dnsServers, server) + servers = append(servers, server) } - return scanner.Err() + return servers, scanner.Err() } -func reverseDNSLookup(ip string, server string) string { - ctx := context.Background() +func reverseDNSLookup(ip string, server string) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() resolver := &net.Resolver{ PreferGo: true, @@ -64,24 +67,40 @@ func reverseDNSLookup(ip string, server string) string { names, err := resolver.LookupAddr(ctx, ip) if err != nil { - return fmt.Sprintf("%s | %s | Error: %s", time.Now().Format("03:04:05 PM"), server, err) + if isNetworkError(err) { + return "", err + } + return fmt.Sprintf("%s | %s | Error: %s", time.Now().Format("03:04:05 PM"), server, err), nil } if len(names) == 0 { - return fmt.Sprintf("%s | %s | No PTR records", time.Now().Format("03:04:05 PM"), server) + return fmt.Sprintf("%s | %s | No PTR records", time.Now().Format("03:04:05 PM"), server), nil } - return fmt.Sprintf("%s | %s | %s", time.Now().Format("03:04:05 PM"), server, names[0]) + return fmt.Sprintf("%s | %s | %s", time.Now().Format("03:04:05 PM"), server, names[0]), nil } -func worker(cidr *net.IPNet, resultsChan chan string) { - for ip := make(net.IP, len(cidr.IP)); copy(ip, cidr.IP) != 0; incrementIP(ip) { - if !cidr.Contains(ip) { - break +func isNetworkError(err error) bool { + errorString := err.Error() + return strings.Contains(errorString, "timeout") || strings.Contains(errorString, "connection refused") +} + +func pickRandomServer(servers []string, triedServers map[string]bool) string { + for _, i := range rand.Perm(len(servers)) { + if !triedServers[servers[i]] { + return servers[i] } - randomServer := dnsServers[rand.Intn(len(dnsServers))] - result := reverseDNSLookup(ip.String(), randomServer) - resultsChan <- result } + return "" +} + +func removeFromList(servers []string, server string) []string { + var newList []string + for _, s := range servers { + if s != server { + newList = append(newList, s) + } + } + return newList } func splitCIDR(cidr string, parts int) ([]*net.IPNet, error) { @@ -91,6 +110,12 @@ func splitCIDR(cidr string, parts int) ([]*net.IPNet, error) { } maskSize, _ := ipNet.Mask.Size() + + maxParts := 1 << uint(32-maskSize) + if parts > maxParts { + parts = maxParts + } + newMaskSize := maskSize for ; (1 << uint(newMaskSize-maskSize)) < parts; newMaskSize++ { if newMaskSize > 32 { @@ -110,6 +135,48 @@ func splitCIDR(cidr string, parts int) ([]*net.IPNet, error) { return subnets, nil } +func worker(cidr *net.IPNet, resultsChan chan string) { + for ip := make(net.IP, len(cidr.IP)); copy(ip, cidr.IP) != 0; incrementIP(ip) { + if !cidr.Contains(ip) { + break + } + + triedServers := make(map[string]bool) + retries := 10 + success := false + + for retries > 0 { + randomServer := pickRandomServer(dnsServers, triedServers) + if randomServer == "" { + break + } + + result, err := reverseDNSLookup(ip.String(), randomServer) + + // Check for network errors + if err != nil && isNetworkError(err) { + failureCounts[randomServer]++ + if failureCounts[randomServer] > 10 { + dnsServers = removeFromList(dnsServers, randomServer) + delete(failureCounts, randomServer) + } + + triedServers[randomServer] = true + retries-- + continue + } else if err == nil { + resultsChan <- result + success = true + break + } + } + + if !success { + resultsChan <- fmt.Sprintf("%s | %s | Max retries reached", time.Now().Format("03:04:05 PM"), ip) + } + } +} + func main() { var cidr string var concurrency int @@ -125,7 +192,9 @@ func main() { os.Exit(1) } - if err := loadDNSServersFromFile(dnsFile); err != nil { + var err error + dnsServers, err = loadDNSServersFromFile(dnsFile) + if err != nil { fmt.Printf("Error reading DNS servers from file %s: %s\n", dnsFile, err) os.Exit(1) } @@ -143,21 +212,24 @@ func main() { os.Exit(1) } - // Create a channel to feed CIDR blocks to workers + if len(subnets) < concurrency { + concurrency = len(subnets) // Limit concurrency to number of subnets + } + cidrChan := make(chan *net.IPNet, len(subnets)) for _, subnet := range subnets { cidrChan <- subnet } - close(cidrChan) // Close it, so workers can detect when there's no more work + close(cidrChan) - resultsChan := make(chan string, concurrency*2) // Increased buffer size for results + resultsChan := make(chan string, concurrency*2) var wg sync.WaitGroup for i := 0; i < concurrency; i++ { wg.Add(1) go func() { defer wg.Done() - for subnet := range cidrChan { // Keep working until there's no more work + for subnet := range cidrChan { worker(subnet, resultsChan) } }()