From 9bed4845b02a5efdd24d9fda8adb3758b2ea5dc5 Mon Sep 17 00:00:00 2001 From: rekby Date: Sat, 25 May 2019 05:09:38 +0300 Subject: [PATCH] Add support of external dns for check domains Close #68 --- cmd/a_main-packr.go | 2 +- cmd/static/default-config.toml | 11 + go.mod | 2 + go.sum | 2 + internal/dns/m_dns_client_mock_test.go | 242 +++++++++++ internal/dns/parallel.go | 68 ++++ internal/dns/parallel_test.go | 95 +++++ internal/dns/resolver.go | 172 ++++++++ internal/dns/resolver_interface_mock_test.go | 243 +++++++++++ internal/dns/resolver_test.go | 405 +++++++++++++++++++ internal/domain_checker/config.go | 44 +- internal/domain_checker/ip_list.go | 13 +- internal/domain_checker/ip_list_test.go | 11 + vendor/modules.txt | 10 + 14 files changed, 1314 insertions(+), 6 deletions(-) create mode 100644 internal/dns/m_dns_client_mock_test.go create mode 100644 internal/dns/parallel.go create mode 100644 internal/dns/parallel_test.go create mode 100644 internal/dns/resolver.go create mode 100644 internal/dns/resolver_interface_mock_test.go create mode 100644 internal/dns/resolver_test.go diff --git a/cmd/a_main-packr.go b/cmd/a_main-packr.go index cb95a54b..20d56206 100644 --- a/cmd/a_main-packr.go +++ b/cmd/a_main-packr.go @@ -7,5 +7,5 @@ import "github.com/gobuffalo/packr" // You can use the "packr clean" command to clean up this, // and any other packr generated files. func init() { - packr.PackJSONBytes("static", "default-config.toml", "\"W0dlbmVyYWxdCgojIFNlY29uZHMgZm9yIGlzc3VlIGV2ZXJ5IGNlcnRpZmljYXRlLiBDYW5jZWwgaXNzdWUgYW5kIHJldHVybiBlcnJvciBpZiB0aW1lb3V0LgpJc3N1ZVRpbWVvdXQgPSAzMDAKCiMgUGF0aCB0byBkaXIsIHdoaWNoIHdpbGwgc3RvcmUgc3RhdGUgYW5kIGNlcnRpZmljYXRlcwpTdG9yYWdlRGlyID0gInN0b3JhZ2UiCgojIFN0b3JlIC5qc29uIGluZm8gd2l0aCBjZXJ0aWZpY2F0ZSBtZXRhZGF0YSBuZWFyIGNlcnRpZmljYXRlLgpTdG9yZUpTT05NZXRhZGF0YSA9IHRydWUKCiMgRGlyZWN0b3J5IHVybCBvZiBhY21lIHNlcnZlci4KI1Rlc3Qgc2VydmVyOiBodHRwczovL2FjbWUtc3RhZ2luZy12MDIuYXBpLmxldHNlbmNyeXB0Lm9yZy9kaXJlY3RvcnkKQWNtZVNlcnZlciA9ICJodHRwczovL2FjbWUtdjAxLmFwaS5sZXRzZW5jcnlwdC5vcmcvZGlyZWN0b3J5IgoKW0xvZ10KRW5hYmxlTG9nVG9GaWxlID0gdHJ1ZQpFbmFibGVMb2dUb1N0ZEVyciA9IHRydWUKCiMgdmVyYm9zZSBsZXZlbCBvZiBsb2csIG9uZSBvZjogZGVidWcsIGluZm8sIHdhcm5pbmcsIGVycm9yLCBmYXRhbApMb2dMZXZlbCA9ICJpbmZvIgoKIyBFbmFibGUgc2VsZiBsb2cgcm90YXRpbmcKRW5hYmxlUm90YXRlID0gdHJ1ZQoKIyBFbmFibGUgZGV2ZWxvcGVyIG1vZGU6IG1vcmUgc3RhY2t0cmFjZXMgYW5kIHBhbmljIChzdG9wIHByb2dyYW0pIG9uIHNvbWUgaW50ZXJuYWwgZXJyb3JzLgpEZXZlbG9wZXJNb2RlID0gZmFsc2UKCiMgUGF0aCB0byBsb2cgZmlsZQpGaWxlID0gImxldHMtcHJveHkubG9nIgoKIyBSb3RhdGUgbG9nIGlmIGN1cnJlbnQgZmlsZSBzaXplIG1vcmUgdGhhbiBYIE1CClJvdGF0ZUJ5U2l6ZU1CID0gMTAwCgojIENvbXByZXNzIG9sZCBsb2cgd2l0aCBnemlwIGFmdGVyIHJvdGF0ZQpDb21wcmVzc1JvdGF0ZWQgPSBmYWxzZQoKIyBEZWxldGUgb2xkIGJhY2t1cHMgYWZ0ZXIgWCBkYXlzLiAwIGZvciBkaXNhYmxlLgpNYXhEYXlzID0gMTAKCiMgRGVsZXRlIG9sZCBiYWNrdXBzIGlmIG9sZCBmaWxlIG51bWJlciBtb3JlIHRoZW4gWC4gMCBmb3IgZGlzYWJsZS4KTWF4Q291bnQgPSAxMAoKW1Byb3h5XQoKIyBEZWZhdWx0IHJ1bGUgb2Ygc2VsZWN0IGRlc3RpbmF0aW9uIGFkZHJlc3MuCiNJdCBjYW4gYmU6IElQICh3aXRoIGRlZmF1bHQgcG9ydCA4MCksIDpQb3J0IChkZWZhdWx0IC0gc2FtZSBJUCBhcyByZWNlaXZlIGNvbm5lY3Rpb24pLCBJUHY0OlBvcnQgb3IgW0lQdjZdOlBvcnQKRGVmYXVsdFRhcmdldCA9ICI6ODAiCgojIEFycmF5IG9mICctJyBzZXBhcmF0ZWQgcGFpcnMgb3IgSVA6UG9ydC4gRm9yIGV4YW1wbGU6CiMgWwojICAgIjEuMi4zLjQ6NDQzLTIuMi4yLjI6MTIzNCIsCiMgICAiMy4zLjMuMzozMzMtWzo6MV06OTQiCiMgIl0KIyBNZWFuOiBjb25uZWN0aW9ucywgYWNjZXB0ZWQgb24gMS4yLjMuNDo0NDMgc2VuZCB0byBzZXJ2ZXIgMi4yLjIuMjoxMjM0CiMgYW5kIGNvbm5lY3Rpb25zIGFjY2VwdGVkIG9uIDMuMy4zLjM6MzMzIHNlbmQgdG8gaXB2NiA6OjEgcG9ydCA5NApUYXJnZXRNYXAgPSBbXQoKIyBBcnJheSBvZiBjb2xvbiBzZXBhcmF0ZWQgSGVhZGVyTmFtZTpIZWFkZXJWYWx1ZSBmb3IgYWRkIHRvIHJlcXVlc3QgZm9yIGJhY2tlbmQuIHt7VmFsdWV9fSBpcyBzcGVjaWFsIGZvcm1zLCB3aGljaCBjYW4KIyBpbnRlcm5hbGx5IHBhcnNpbmcuIE5vdyBpdCBzdXBwb3J0IG9ubHkgc3BlY2lhbCB2YWx1ZXM6CiMge3tDT05ORUNUSU9OX0lEfX0gLSBJZCBvZiBhY2NlcHRlZCBjb25uZWN0aW9uLCBnZW5lcmF0ZWQgYnkgbGV0cy1wcm94eQojIHt7SFRUUF9QUk9UT319IC0gc2V0IHRvIGh0dHAvaHR0cHMgZGVwZW5kZW5jZSBpbmNvbWluZyBjb25uZWN0aW9ucyBoYW5kbGVkCiMge3tTT1VSQ0VfSVB9fSAtIFJlbW90ZSBJUCBvZiBpbmNvbWluZyBjb25uZWN0aW9uCiMge1NPVVJDRV9QT1JUfX0gLSBSZW1vdGUgcG9ydCBvZiBpbmNvbWluZyBjb25uZWN0aW9uCiMge3tTT1VSQ0VfSVB9fTp7e1NPVVJDRV9QT1JUfX0gLSBSZW1vdGUgSVA6UG9ydCBvZiBpbmNvbWluZyBjb25uZWN0aW9uLgojIE5vdyBpdCBhY2NlcHRlZCBvbmx5IHRoaXMgc3BlY2lhbCB2YWx1ZXMsIHdoaWNoIG11c3QgYmUgZXhheGx0eSBlcXVhbCB0byBleGFtcGxlcy4gQWxsIG90aGVyIHZhbHVlcyBzZW5kIGFzIGlzLgojIEJ1dCBpdCBjYW4gY2hhbmdlIGFuZCBleHRlbmQgaW4gZnV0dXJlLiBEb2Vzbid0IHVzZSB7ey4uLn19IGFzIG93biB2YWx1ZXMuCiMgRXhhbXBsZToKIyBbIklQOnt7U09VUkNFX0lQfX0iLCAiUHJveHk6bGV0cy1wcm94eSIsICJQcm90b2NvbDp7e0hUVFBfUFJPVE99fSIgXQpIZWFkZXJzID0gW10KCltDaGVja0RvbWFpbnNdCgojIEFsbG93IGRvbWFpbiBpZiBpdCByZXNvbHZlciBmb3Igb25lIG9mIHB1YmxpYyBJUHMgb2YgdGhpcyBzZXJ2ZXIuCklQU2VsZiA9IHRydWUKCiMgQWxsb3cgZG9tYWluIGlmIGl0IHJlc29sdmVyIGZvciBvbmUgb2YgdGhlIGlwcy4KSVBXaGl0ZUxpc3QgPSAiIgoKIyBSZWdleHAgaW4gZ29sYW5nIHN5bnRheCBvZiBibGFja2xpc3RlZCBkb21haW4gZm9yIGlzc3VlIGNlcnRpZmljYXRlLgojVGhpcyBsaXN0IG92ZXJyaWRlZCBieSB3aGl0ZWxpc3QuCkJsYWNrTGlzdCA9ICIiCgojIFJlZ2V4cCBpbiBnb2xhbmcgc3ludGF4IG9mIHdoaXRlbGlzdCBkb21haW5zIGZvciBpc3N1ZSBjZXJ0aWZpY2F0ZS4KI1doaXRlbGlzdCBuZWVkIGZvciBhbGxvdyBwYXJ0IG9mIGRvbWFpbnMsIHdoaWNoIGV4Y2x1ZGVkIGJ5IGJsYWNrbGlzdC4KIwpXaGl0ZUxpc3QgPSAiIgoKW0xpc3Rlbl0KCiMgQmluZCBhZGRyZXNzZXMgZm9yIFRMUyBsaXN0ZW5lcnMKVExTQWRkcmVzc2VzID0gWyI6NDQzIl0KCiMgQmluZCBhZGRyZXNzZXMgd2l0aG91dCBUTFMgc2VjdXJlIChmb3IgSFRUUCByZXZlcnNlIHByb3h5IGFuZCBodHRwLTAxIHZhbGlkYXRpb24gd2l0aG91dCByZWRpcmVjdCB0byBodHRwcykKVENQQWRkcmVzc2VzID0gW10K\"") + packr.PackJSONBytes("static", "default-config.toml", "\"W0dlbmVyYWxdCgojIFNlY29uZHMgZm9yIGlzc3VlIGV2ZXJ5IGNlcnRpZmljYXRlLiBDYW5jZWwgaXNzdWUgYW5kIHJldHVybiBlcnJvciBpZiB0aW1lb3V0LgpJc3N1ZVRpbWVvdXQgPSAzMDAKCiMgUGF0aCB0byBkaXIsIHdoaWNoIHdpbGwgc3RvcmUgc3RhdGUgYW5kIGNlcnRpZmljYXRlcwpTdG9yYWdlRGlyID0gInN0b3JhZ2UiCgojIFN0b3JlIC5qc29uIGluZm8gd2l0aCBjZXJ0aWZpY2F0ZSBtZXRhZGF0YSBuZWFyIGNlcnRpZmljYXRlLgpTdG9yZUpTT05NZXRhZGF0YSA9IHRydWUKCiMgRGlyZWN0b3J5IHVybCBvZiBhY21lIHNlcnZlci4KI1Rlc3Qgc2VydmVyOiBodHRwczovL2FjbWUtc3RhZ2luZy12MDIuYXBpLmxldHNlbmNyeXB0Lm9yZy9kaXJlY3RvcnkKQWNtZVNlcnZlciA9ICJodHRwczovL2FjbWUtdjAxLmFwaS5sZXRzZW5jcnlwdC5vcmcvZGlyZWN0b3J5IgoKW0xvZ10KRW5hYmxlTG9nVG9GaWxlID0gdHJ1ZQpFbmFibGVMb2dUb1N0ZEVyciA9IHRydWUKCiMgdmVyYm9zZSBsZXZlbCBvZiBsb2csIG9uZSBvZjogZGVidWcsIGluZm8sIHdhcm5pbmcsIGVycm9yLCBmYXRhbApMb2dMZXZlbCA9ICJpbmZvIgoKIyBFbmFibGUgc2VsZiBsb2cgcm90YXRpbmcKRW5hYmxlUm90YXRlID0gdHJ1ZQoKIyBFbmFibGUgZGV2ZWxvcGVyIG1vZGU6IG1vcmUgc3RhY2t0cmFjZXMgYW5kIHBhbmljIChzdG9wIHByb2dyYW0pIG9uIHNvbWUgaW50ZXJuYWwgZXJyb3JzLgpEZXZlbG9wZXJNb2RlID0gZmFsc2UKCiMgUGF0aCB0byBsb2cgZmlsZQpGaWxlID0gImxldHMtcHJveHkubG9nIgoKIyBSb3RhdGUgbG9nIGlmIGN1cnJlbnQgZmlsZSBzaXplIG1vcmUgdGhhbiBYIE1CClJvdGF0ZUJ5U2l6ZU1CID0gMTAwCgojIENvbXByZXNzIG9sZCBsb2cgd2l0aCBnemlwIGFmdGVyIHJvdGF0ZQpDb21wcmVzc1JvdGF0ZWQgPSBmYWxzZQoKIyBEZWxldGUgb2xkIGJhY2t1cHMgYWZ0ZXIgWCBkYXlzLiAwIGZvciBkaXNhYmxlLgpNYXhEYXlzID0gMTAKCiMgRGVsZXRlIG9sZCBiYWNrdXBzIGlmIG9sZCBmaWxlIG51bWJlciBtb3JlIHRoZW4gWC4gMCBmb3IgZGlzYWJsZS4KTWF4Q291bnQgPSAxMAoKW1Byb3h5XQoKIyBEZWZhdWx0IHJ1bGUgb2Ygc2VsZWN0IGRlc3RpbmF0aW9uIGFkZHJlc3MuCiNJdCBjYW4gYmU6IElQICh3aXRoIGRlZmF1bHQgcG9ydCA4MCksIDpQb3J0IChkZWZhdWx0IC0gc2FtZSBJUCBhcyByZWNlaXZlIGNvbm5lY3Rpb24pLCBJUHY0OlBvcnQgb3IgW0lQdjZdOlBvcnQKRGVmYXVsdFRhcmdldCA9ICI6ODAiCgojIEFycmF5IG9mICctJyBzZXBhcmF0ZWQgcGFpcnMgb3IgSVA6UG9ydC4gRm9yIGV4YW1wbGU6CiMgWwojICAgIjEuMi4zLjQ6NDQzLTIuMi4yLjI6MTIzNCIsCiMgICAiMy4zLjMuMzozMzMtWzo6MV06OTQiCiMgIl0KIyBNZWFuOiBjb25uZWN0aW9ucywgYWNjZXB0ZWQgb24gMS4yLjMuNDo0NDMgc2VuZCB0byBzZXJ2ZXIgMi4yLjIuMjoxMjM0CiMgYW5kIGNvbm5lY3Rpb25zIGFjY2VwdGVkIG9uIDMuMy4zLjM6MzMzIHNlbmQgdG8gaXB2NiA6OjEgcG9ydCA5NApUYXJnZXRNYXAgPSBbXQoKIyBBcnJheSBvZiBjb2xvbiBzZXBhcmF0ZWQgSGVhZGVyTmFtZTpIZWFkZXJWYWx1ZSBmb3IgYWRkIHRvIHJlcXVlc3QgZm9yIGJhY2tlbmQuIHt7VmFsdWV9fSBpcyBzcGVjaWFsIGZvcm1zLCB3aGljaCBjYW4KIyBpbnRlcm5hbGx5IHBhcnNpbmcuIE5vdyBpdCBzdXBwb3J0IG9ubHkgc3BlY2lhbCB2YWx1ZXM6CiMge3tDT05ORUNUSU9OX0lEfX0gLSBJZCBvZiBhY2NlcHRlZCBjb25uZWN0aW9uLCBnZW5lcmF0ZWQgYnkgbGV0cy1wcm94eQojIHt7SFRUUF9QUk9UT319IC0gc2V0IHRvIGh0dHAvaHR0cHMgZGVwZW5kZW5jZSBpbmNvbWluZyBjb25uZWN0aW9ucyBoYW5kbGVkCiMge3tTT1VSQ0VfSVB9fSAtIFJlbW90ZSBJUCBvZiBpbmNvbWluZyBjb25uZWN0aW9uCiMge1NPVVJDRV9QT1JUfX0gLSBSZW1vdGUgcG9ydCBvZiBpbmNvbWluZyBjb25uZWN0aW9uCiMge3tTT1VSQ0VfSVB9fTp7e1NPVVJDRV9QT1JUfX0gLSBSZW1vdGUgSVA6UG9ydCBvZiBpbmNvbWluZyBjb25uZWN0aW9uLgojIE5vdyBpdCBhY2NlcHRlZCBvbmx5IHRoaXMgc3BlY2lhbCB2YWx1ZXMsIHdoaWNoIG11c3QgYmUgZXhheGx0eSBlcXVhbCB0byBleGFtcGxlcy4gQWxsIG90aGVyIHZhbHVlcyBzZW5kIGFzIGlzLgojIEJ1dCBpdCBjYW4gY2hhbmdlIGFuZCBleHRlbmQgaW4gZnV0dXJlLiBEb2Vzbid0IHVzZSB7ey4uLn19IGFzIG93biB2YWx1ZXMuCiMgRXhhbXBsZToKIyBbIklQOnt7U09VUkNFX0lQfX0iLCAiUHJveHk6bGV0cy1wcm94eSIsICJQcm90b2NvbDp7e0hUVFBfUFJPVE99fSIgXQpIZWFkZXJzID0gW10KCltDaGVja0RvbWFpbnNdCgojIEFsbG93IGRvbWFpbiBpZiBpdCByZXNvbHZlciBmb3Igb25lIG9mIHB1YmxpYyBJUHMgb2YgdGhpcyBzZXJ2ZXIuCklQU2VsZiA9IHRydWUKCiMgQWxsb3cgZG9tYWluIGlmIGl0IHJlc29sdmVyIGZvciBvbmUgb2YgdGhlIGlwcy4KSVBXaGl0ZUxpc3QgPSAiIgoKIyBSZWdleHAgaW4gZ29sYW5nIHN5bnRheCBvZiBibGFja2xpc3RlZCBkb21haW4gZm9yIGlzc3VlIGNlcnRpZmljYXRlLgojVGhpcyBsaXN0IG92ZXJyaWRlZCBieSB3aGl0ZWxpc3QuCkJsYWNrTGlzdCA9ICIiCgojIFJlZ2V4cCBpbiBnb2xhbmcgc3ludGF4IG9mIHdoaXRlbGlzdCBkb21haW5zIGZvciBpc3N1ZSBjZXJ0aWZpY2F0ZS4KI1doaXRlbGlzdCBuZWVkIGZvciBhbGxvdyBwYXJ0IG9mIGRvbWFpbnMsIHdoaWNoIGV4Y2x1ZGVkIGJ5IGJsYWNrbGlzdC4KIwpXaGl0ZUxpc3QgPSAiIgoKIyBDb21tYSBzZXBhcmF0ZWQgZG5zIHNlcnZlciwgdXNlZCBmb3IgcmVzb2x2ZSBpcDpwb3J0IGFkZHJlc3Mgb2YgZG9tYWlucyB3aGlsZSBjaGVjayBpdC4KIyBpZiBlbXB0eSAtIHVzZSBzeXN0ZW0gZG5zIHJlc29sdmVyICh1c3VhbGx5IGluY2x1ZGUgaG9zdHMgZmlsZSwgY2FjaGUsIGV0YykKIyBpZiBzZXQgLSB1c2UgZGlyZWN0IGRucyBxdWVyaWVzIGZvciBzZXJ2ZXJzLCB3aXRob3V0IHNlbGYgY2FjaGUuCiMgaWYgc2V0IG1vcmUsIHRoYW4gb25lIGRucyBzZXJ2ZXIgLSBzZW5kIHF1ZXJpZXMgaW4gcGFyYWxsZWwgdG8gYWxsIHNlcnZlcnMuCiMgZXJyb3IgcmVzdWx0cyBmcm9tIHBhcnQgb2Ygc2VydmVycyAtIGlnbm9yZS4gTmVlZCBtaW5pbXVtIG9uZSBhbnN3ZXIuCiMgaWYgZGlmZmVyZW50IGRucyBzZXJ2ZXJzIHJldHVybiBkaWZmZXJlbnQgaXAgYWRkcmVzc2VzIC0gYWxsIG9mIHRoZW0gdXNlIGZvciBjaGVjawojIEV4YW1wbGU6ICI4LjguOC44OjUzLDEuMS4xLjE6NTMsNzcuODguOC44OjUzLFsyYTAyOjZiODo6ZmVlZDowZmZdOjUzLFsyMDAxOjQ4NjA6NDg2MDo6ODg4OF06NTMiClJlc29sdmVyID0gIiIKCgoKW0xpc3Rlbl0KCiMgQmluZCBhZGRyZXNzZXMgZm9yIFRMUyBsaXN0ZW5lcnMKVExTQWRkcmVzc2VzID0gWyI6NDQzIl0KCiMgQmluZCBhZGRyZXNzZXMgd2l0aG91dCBUTFMgc2VjdXJlIChmb3IgSFRUUCByZXZlcnNlIHByb3h5IGFuZCBodHRwLTAxIHZhbGlkYXRpb24gd2l0aG91dCByZWRpcmVjdCB0byBodHRwcykKVENQQWRkcmVzc2VzID0gW10K\"") } diff --git a/cmd/static/default-config.toml b/cmd/static/default-config.toml index fc3bf093..9d915a9e 100644 --- a/cmd/static/default-config.toml +++ b/cmd/static/default-config.toml @@ -86,6 +86,17 @@ BlackList = "" # WhiteList = "" +# Comma separated dns server, used for resolve ip:port address of domains while check it. +# if empty - use system dns resolver (usually include hosts file, cache, etc) +# if set - use direct dns queries for servers, without self cache. +# if set more, than one dns server - send queries in parallel to all servers. +# error results from part of servers - ignore. Need minimum one answer. +# if different dns servers return different ip addresses - all of them use for check +# Example: "8.8.8.8:53,1.1.1.1:53,77.88.8.8:53,[2a02:6b8::feed:0ff]:53,[2001:4860:4860::8888]:53" +Resolver = "" + + + [Listen] # Bind addresses for TLS listeners diff --git a/go.mod b/go.mod index 67034d76..458805ad 100644 --- a/go.mod +++ b/go.mod @@ -9,8 +9,10 @@ require ( github.com/golangci/golangci-lint v1.16.0 // indirect github.com/kardianos/minwinsvc v0.0.0-20151122163309-cad6b2b879b0 github.com/maxatome/go-testdeep v1.0.8 + github.com/miekg/dns v1.1.12 github.com/mitchellh/gox v1.0.1 // indirect github.com/pelletier/go-toml v1.3.0 // indirect + github.com/pkg/errors v0.8.1 github.com/rekby/zapcontext v0.0.3 github.com/satori/go.uuid v1.2.0 go.uber.org/zap v1.9.1 diff --git a/go.sum b/go.sum index 64c4a301..e568dff7 100644 --- a/go.sum +++ b/go.sum @@ -152,6 +152,8 @@ github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpe github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/maxatome/go-testdeep v1.0.8 h1:OORprsdGMEDRo1diHCQzxCAaZFWKlHDrBnPttC4wL8g= github.com/maxatome/go-testdeep v1.0.8/go.mod h1:Vcp0RXXOMhUTw2S2HRmUIqHQpG2Oxz+HM/WSWK7yXto= +github.com/miekg/dns v1.1.12 h1:WMhc1ik4LNkTg8U9l3hI1LvxKmIL+f1+WV/SZtCbDDA= +github.com/miekg/dns v1.1.12/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/go-homedir v1.0.0 h1:vKb8ShqSby24Yrqr/yDYkuFz8d0WUjys40rvnGC8aR0= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-ps v0.0.0-20170309133038-4fdf99ab2936/go.mod h1:r1VsdOzOPt1ZSrGZWFoNhsAedKnEd6r9Np1+5blZCWk= diff --git a/internal/dns/m_dns_client_mock_test.go b/internal/dns/m_dns_client_mock_test.go new file mode 100644 index 00000000..86b84fe7 --- /dev/null +++ b/internal/dns/m_dns_client_mock_test.go @@ -0,0 +1,242 @@ +package dns + +// DO NOT EDIT! +// The code below was generated with http://github.com/gojuno/minimock (dev) + +//go:generate minimock -i github.com/rekby/lets-proxy2/internal/dns.mDNSClient -o ./m_dns_client_mock_test.go + +import ( + "sync/atomic" + "time" + + mdns "github.com/miekg/dns" + + "github.com/gojuno/minimock" +) + +// MDNSClientMock implements mDNSClient +type MDNSClientMock struct { + t minimock.Tester + + funcExchange func(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) + afterExchangeCounter uint64 + beforeExchangeCounter uint64 + ExchangeMock mMDNSClientMockExchange +} + +// NewMDNSClientMock returns a mock for mDNSClient +func NewMDNSClientMock(t minimock.Tester) *MDNSClientMock { + m := &MDNSClientMock{t: t} + if controller, ok := t.(minimock.MockController); ok { + controller.RegisterMocker(m) + } + m.ExchangeMock = mMDNSClientMockExchange{mock: m} + + return m +} + +type mMDNSClientMockExchange struct { + mock *MDNSClientMock + defaultExpectation *MDNSClientMockExchangeExpectation + expectations []*MDNSClientMockExchangeExpectation +} + +// MDNSClientMockExchangeExpectation specifies expectation struct of the mDNSClient.Exchange +type MDNSClientMockExchangeExpectation struct { + mock *MDNSClientMock + params *MDNSClientMockExchangeParams + results *MDNSClientMockExchangeResults + Counter uint64 +} + +// MDNSClientMockExchangeParams contains parameters of the mDNSClient.Exchange +type MDNSClientMockExchangeParams struct { + msg *mdns.Msg + address string +} + +// MDNSClientMockExchangeResults contains results of the mDNSClient.Exchange +type MDNSClientMockExchangeResults struct { + r *mdns.Msg + rtt time.Duration + err error +} + +// Expect sets up expected params for mDNSClient.Exchange +func (m *mMDNSClientMockExchange) Expect(msg *mdns.Msg, address string) *mMDNSClientMockExchange { + if m.mock.funcExchange != nil { + m.mock.t.Fatalf("MDNSClientMock.Exchange mock is already set by Set") + } + + if m.defaultExpectation == nil { + m.defaultExpectation = &MDNSClientMockExchangeExpectation{} + } + + m.defaultExpectation.params = &MDNSClientMockExchangeParams{msg, address} + for _, e := range m.expectations { + if minimock.Equal(e.params, m.defaultExpectation.params) { + m.mock.t.Fatalf("Expectation set by When has same params: %#v", *m.defaultExpectation.params) + } + } + + return m +} + +// Return sets up results that will be returned by mDNSClient.Exchange +func (m *mMDNSClientMockExchange) Return(r *mdns.Msg, rtt time.Duration, err error) *MDNSClientMock { + if m.mock.funcExchange != nil { + m.mock.t.Fatalf("MDNSClientMock.Exchange mock is already set by Set") + } + + if m.defaultExpectation == nil { + m.defaultExpectation = &MDNSClientMockExchangeExpectation{mock: m.mock} + } + m.defaultExpectation.results = &MDNSClientMockExchangeResults{r, rtt, err} + return m.mock +} + +//Set uses given function f to mock the mDNSClient.Exchange method +func (m *mMDNSClientMockExchange) Set(f func(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error)) *MDNSClientMock { + if m.defaultExpectation != nil { + m.mock.t.Fatalf("Default expectation is already set for the mDNSClient.Exchange method") + } + + if len(m.expectations) > 0 { + m.mock.t.Fatalf("Some expectations are already set for the mDNSClient.Exchange method") + } + + m.mock.funcExchange = f + return m.mock +} + +// When sets expectation for the mDNSClient.Exchange which will trigger the result defined by the following +// Then helper +func (m *mMDNSClientMockExchange) When(msg *mdns.Msg, address string) *MDNSClientMockExchangeExpectation { + if m.mock.funcExchange != nil { + m.mock.t.Fatalf("MDNSClientMock.Exchange mock is already set by Set") + } + + expectation := &MDNSClientMockExchangeExpectation{ + mock: m.mock, + params: &MDNSClientMockExchangeParams{msg, address}, + } + m.expectations = append(m.expectations, expectation) + return expectation +} + +// Then sets up mDNSClient.Exchange return parameters for the expectation previously defined by the When method +func (e *MDNSClientMockExchangeExpectation) Then(r *mdns.Msg, rtt time.Duration, err error) *MDNSClientMock { + e.results = &MDNSClientMockExchangeResults{r, rtt, err} + return e.mock +} + +// Exchange implements mDNSClient +func (m *MDNSClientMock) Exchange(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) { + atomic.AddUint64(&m.beforeExchangeCounter, 1) + defer atomic.AddUint64(&m.afterExchangeCounter, 1) + + for _, e := range m.ExchangeMock.expectations { + if minimock.Equal(*e.params, MDNSClientMockExchangeParams{msg, address}) { + atomic.AddUint64(&e.Counter, 1) + return e.results.r, e.results.rtt, e.results.err + } + } + + if m.ExchangeMock.defaultExpectation != nil { + atomic.AddUint64(&m.ExchangeMock.defaultExpectation.Counter, 1) + want := m.ExchangeMock.defaultExpectation.params + got := MDNSClientMockExchangeParams{msg, address} + if want != nil && !minimock.Equal(*want, got) { + m.t.Errorf("MDNSClientMock.Exchange got unexpected parameters, want: %#v, got: %#v%s\n", *want, got, minimock.Diff(*want, got)) + } + + results := m.ExchangeMock.defaultExpectation.results + if results == nil { + m.t.Fatal("No results are set for the MDNSClientMock.Exchange") + } + return (*results).r, (*results).rtt, (*results).err + } + if m.funcExchange != nil { + return m.funcExchange(msg, address) + } + m.t.Fatalf("Unexpected call to MDNSClientMock.Exchange. %v %v", msg, address) + return +} + +// ExchangeAfterCounter returns a count of finished MDNSClientMock.Exchange invocations +func (m *MDNSClientMock) ExchangeAfterCounter() uint64 { + return atomic.LoadUint64(&m.afterExchangeCounter) +} + +// ExchangeBeforeCounter returns a count of MDNSClientMock.Exchange invocations +func (m *MDNSClientMock) ExchangeBeforeCounter() uint64 { + return atomic.LoadUint64(&m.beforeExchangeCounter) +} + +// MinimockExchangeDone returns true if the count of the Exchange invocations corresponds +// the number of defined expectations +func (m *MDNSClientMock) MinimockExchangeDone() bool { + for _, e := range m.ExchangeMock.expectations { + if atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + // if default expectation was set then invocations count should be greater than zero + if m.ExchangeMock.defaultExpectation != nil && atomic.LoadUint64(&m.afterExchangeCounter) < 1 { + return false + } + // if func was set then invocations count should be greater than zero + if m.funcExchange != nil && atomic.LoadUint64(&m.afterExchangeCounter) < 1 { + return false + } + return true +} + +// MinimockExchangeInspect logs each unmet expectation +func (m *MDNSClientMock) MinimockExchangeInspect() { + for _, e := range m.ExchangeMock.expectations { + if atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to MDNSClientMock.Exchange with params: %#v", *e.params) + } + } + + // if default expectation was set then invocations count should be greater than zero + if m.ExchangeMock.defaultExpectation != nil && atomic.LoadUint64(&m.afterExchangeCounter) < 1 { + m.t.Errorf("Expected call to MDNSClientMock.Exchange with params: %#v", *m.ExchangeMock.defaultExpectation.params) + } + // if func was set then invocations count should be greater than zero + if m.funcExchange != nil && atomic.LoadUint64(&m.afterExchangeCounter) < 1 { + m.t.Error("Expected call to MDNSClientMock.Exchange") + } +} + +// MinimockFinish checks that all mocked methods have been called the expected number of times +func (m *MDNSClientMock) MinimockFinish() { + if !m.minimockDone() { + m.MinimockExchangeInspect() + m.t.FailNow() + } +} + +// MinimockWait waits for all mocked methods to be called the expected number of times +func (m *MDNSClientMock) MinimockWait(timeout time.Duration) { + timeoutCh := time.After(timeout) + for { + if m.minimockDone() { + return + } + select { + case <-timeoutCh: + m.MinimockFinish() + return + case <-time.After(10 * time.Millisecond): + } + } +} + +func (m *MDNSClientMock) minimockDone() bool { + done := true + return done && + m.MinimockExchangeDone() +} diff --git a/internal/dns/parallel.go b/internal/dns/parallel.go new file mode 100644 index 00000000..6a53972d --- /dev/null +++ b/internal/dns/parallel.go @@ -0,0 +1,68 @@ +package dns + +import ( + "context" + "net" + "sync" +) + +type ResolverInterface interface { + // LookupIPAddr return ip addresses of domain. It MUST finish work when context canceled + LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) +} + +type Parallel []ResolverInterface + +// NewParallel return parallel resolver +func NewParallel(resolvers ...ResolverInterface) Parallel { + state := make(Parallel, len(resolvers)) + copy(state, resolvers) + return state +} + +// LookupIPAddr return ip addresses of host, used underly resolvers in parallel +// If any of resolvers return ips - return sum array of the ips (may duplicated) +// If all resolvers return error - return any of they errors +func (p Parallel) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) { + switch len(p) { + case 0: + return nil, nil + case 1: + return p[0].LookupIPAddr(ctx, host) + default: + // pass + } + + var ips = make([][]net.IPAddr, len(p)) + var errs = make([]error, len(p)) + + var wg sync.WaitGroup + wg.Add(len(p)) + for i := range p { + go func(i int) { + ips[i], errs[i] = p[i].LookupIPAddr(ctx, host) + wg.Done() + }(i) + } + + wg.Wait() + + resLen := 0 + var err error + for i := range ips { + resLen += len(ips[i]) + if errs[i] != nil { + err = errs[i] + } + } + + if resLen == 0 { + return nil, err + } + + resIps := make([]net.IPAddr, 0, resLen) + for i := range ips { + resIps = append(resIps, ips[i]...) + } + return resIps, nil +} diff --git a/internal/dns/parallel_test.go b/internal/dns/parallel_test.go new file mode 100644 index 00000000..1e392112 --- /dev/null +++ b/internal/dns/parallel_test.go @@ -0,0 +1,95 @@ +package dns + +import ( + "net" + "testing" + + "github.com/pkg/errors" + + "github.com/gojuno/minimock" + + "github.com/maxatome/go-testdeep" + "github.com/rekby/lets-proxy2/internal/th" +) + +var ( + _ ResolverInterface = Parallel{} +) + +func TestParallel(t *testing.T) { + ctx, cancel := th.TestContext() + defer cancel() + + td := testdeep.NewT(t) + mc := minimock.NewController(td) + + var ips []net.IPAddr + var err error + + p := NewParallel() + ips, err = p.LookupIPAddr(ctx, "123") + td.CmpNoError(err) + td.Nil(ips) + + r1 := NewResolverInterfaceMock(mc) + r2 := NewResolverInterfaceMock(mc) + + p = NewParallel(r1) + + r1.LookupIPAddrMock.When(ctx, "1").Then([]net.IPAddr{{IP: net.ParseIP("1.2.3.4")}}, nil) + ips, err = p.LookupIPAddr(ctx, "1") + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.ParseIP("1.2.3.4")}}) + + testErr := errors.New("test2") + r1.LookupIPAddrMock.When(ctx, "2").Then(nil, testErr) + ips, err = p.LookupIPAddr(ctx, "2") + td.CmpDeeply(err, testErr) + td.Nil(ips) + + p = NewParallel(r1, r2) + r1.LookupIPAddrMock.When(ctx, "3").Then([]net.IPAddr{{IP: net.ParseIP("1.2.3.4")}}, nil) + r2.LookupIPAddrMock.When(ctx, "3").Then([]net.IPAddr{{IP: net.ParseIP("4.5.6.7")}}, nil) + ips, err = p.LookupIPAddr(ctx, "3") + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.ParseIP("1.2.3.4")}, {IP: net.ParseIP("4.5.6.7")}}) + + r1.LookupIPAddrMock.When(ctx, "4").Then([]net.IPAddr{{IP: net.ParseIP("1.2.3.4")}}, nil) + r2.LookupIPAddrMock.When(ctx, "4").Then(nil, errors.New("test4")) + ips, err = p.LookupIPAddr(ctx, "4") + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.ParseIP("1.2.3.4")}}) + + r1.LookupIPAddrMock.When(ctx, "5").Then(nil, errors.New("test5")) + r2.LookupIPAddrMock.When(ctx, "5").Then([]net.IPAddr{{IP: net.ParseIP("4.5.6.7")}}, nil) + ips, err = p.LookupIPAddr(ctx, "5") + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.ParseIP("4.5.6.7")}}) + + error61 := errors.New("test6-1") + error62 := errors.New("test6-2") + r1.LookupIPAddrMock.When(ctx, "6").Then(nil, error61) + r2.LookupIPAddrMock.When(ctx, "6").Then(nil, error62) + ips, err = p.LookupIPAddr(ctx, "6") + td.Any(err, []interface{}{error61, error62}) + td.Nil(ips) + +} + +func TestParallelReadl(t *testing.T) { + ctx, cancel := th.TestContext() + defer cancel() + + td := testdeep.NewT(t) + + r := NewParallel(NewResolver("8.8.8.8:53"), NewResolver("4.4.4.4:53")) + ips, err := r.LookupIPAddr(ctx, "one.one.one.one") + td.CmpNoError(err) + td.Contains(ips, + testdeep.Any( + net.IPAddr{IP: net.IPv4(1, 1, 1, 1).To4()}, + net.IPAddr{IP: net.IPv4(1, 1, 1, 1).To16()}, + ), + ) + +} diff --git a/internal/dns/resolver.go b/internal/dns/resolver.go new file mode 100644 index 00000000..e7e44162 --- /dev/null +++ b/internal/dns/resolver.go @@ -0,0 +1,172 @@ +package dns + +import ( + "context" + "errors" + "net" + "strings" + "sync" + "time" + + zc "github.com/rekby/zapcontext" + + "github.com/rekby/lets-proxy2/internal/log" + "go.uber.org/zap" + + mdns "github.com/miekg/dns" +) + +var ( + errTruncatedResponse = errors.New("truncated answer") + errPanic = errors.New("panic") +) + +type mDNSClient interface { + Exchange(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) +} + +// Resolve IPs for A and AAAA records of domains +// it use direct dns query without cache +type Resolver struct { + udp mDNSClient + tcp mDNSClient + server string + maxDNSRecursionDeep int + lookupWithClient func(ctx context.Context, host string, server string, recordType uint16, recursion int, client mDNSClient) ([]net.IPAddr, error) +} + +// NewResolver return direct dns resolver +func NewResolver(dnsServer string) *Resolver { + return &Resolver{ + udp: &mdns.Client{Net: "udp"}, + tcp: &mdns.Client{Net: "tcp"}, + server: dnsServer, + maxDNSRecursionDeep: 10, + lookupWithClient: lookupWithClient, + } +} + +func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) { + ctx = zc.WithLogger(ctx, zc.L(ctx).With(zap.String("dns_server", r.server))) + if !strings.HasSuffix(host, ".") { + host += "." + } + + var wg sync.WaitGroup + + var ipAddrA, ipAddrAAAA []net.IPAddr + var errA, errAAAA error + + wg.Add(1) + go func() { + defer wg.Done() + errA = errPanic + ipAddrA, errA = r.lookup(ctx, host, mdns.TypeA) + }() + + wg.Add(1) + go func() { + defer wg.Done() + errAAAA = errPanic + ipAddrAAAA, errAAAA = r.lookup(ctx, host, mdns.TypeAAAA) + }() + + wg.Wait() + + var resultErr error + if errAAAA != nil { + resultErr = errAAAA + } + if errA != nil { + resultErr = errA + } + log.DebugErrorCtx(ctx, resultErr, "Host lookup", zap.NamedError("errA", errA), + zap.NamedError("errAAAA", errAAAA), zap.Any("ipAddrA", ipAddrA), + zap.Any("ipAddrAAAA", ipAddrAAAA)) + + if resultErr != nil { + return nil, resultErr + } + resultIPs := make([]net.IPAddr, len(ipAddrA)+len(ipAddrAAAA)) + copy(resultIPs, ipAddrA) + copy(resultIPs[len(ipAddrA):], ipAddrAAAA) + return resultIPs, nil +} + +func (r *Resolver) lookup(ctx context.Context, host string, recordType uint16) ([]net.IPAddr, error) { + res, err := r.lookupWithClient(ctx, host, r.server, recordType, r.maxDNSRecursionDeep, r.udp) + if err == errTruncatedResponse { + zc.L(ctx).Debug("fallback to tcp request") + res, err = r.lookupWithClient(ctx, host, r.server, recordType, r.maxDNSRecursionDeep, r.tcp) + } + return res, err +} + +func lookupWithClient(ctx context.Context, host string, server string, recordType uint16, recursion int, client mDNSClient) (ipResults []net.IPAddr, err error) { + logger := zc.L(ctx) + + if recursion <= 0 { + logger.Error("Max recursion while resolve domain") + return nil, errors.New("max recursion while resolve domain") + } + + if ctx.Err() != nil { + logger.Debug("Context canceled") + return nil, ctx.Err() + } + defer func() { + log.DebugError(logger, err, "Resolved ips", zap.Any("ipResults", ipResults), + zap.Uint16("record_type", recordType)) + }() + + var msdID uint16 + for msdID == 0 { + msdID = mdns.Id() + } + + msg := mdns.Msg{ + MsgHdr: mdns.MsgHdr{ + Id: msdID, + }, + Question: []mdns.Question{ + {Name: host, Qclass: mdns.ClassINET, Qtype: recordType}, + }, + } + msg.RecursionDesired = true + exchangeCompleted := make(chan struct{}) + + var dnsAnswer *mdns.Msg + go func() { + dnsAnswer, _, err = client.Exchange(&msg, server) + close(exchangeCompleted) + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-exchangeCompleted: + if dnsAnswer != nil && dnsAnswer.Truncated { + return nil, errTruncatedResponse + } + if err != nil { + return nil, err + } + var resIPs []net.IPAddr + for _, r := range dnsAnswer.Answer { + rType := r.Header().Rrtype + switch { + case rType == mdns.TypeA && rType == recordType: + resIPs = append(resIPs, net.IPAddr{IP: r.(*mdns.A).A}) + case rType == mdns.TypeAAAA && rType == recordType: + resIPs = append(resIPs, net.IPAddr{IP: r.(*mdns.AAAA).AAAA}) + case rType == mdns.TypeCNAME: + cname := r.(*mdns.CNAME) + zc.L(ctx).Debug("Receive CNAME record for domain.", zap.String("target", cname.Target)) + return lookupWithClient(ctx, cname.Target, server, recordType, recursion-1, client) + default: + // pass + } + } + return resIPs, nil + } +} diff --git a/internal/dns/resolver_interface_mock_test.go b/internal/dns/resolver_interface_mock_test.go new file mode 100644 index 00000000..105b4b21 --- /dev/null +++ b/internal/dns/resolver_interface_mock_test.go @@ -0,0 +1,243 @@ +package dns + +// DO NOT EDIT! +// The code below was generated with http://github.com/gojuno/minimock (dev) + +//go:generate minimock -i github.com/rekby/lets-proxy2/internal/dns.ResolverInterface -o ./resolver_interface_mock_test.go + +import ( + "sync/atomic" + "time" + + "context" + + "net" + + "github.com/gojuno/minimock" +) + +// ResolverInterfaceMock implements ResolverInterface +type ResolverInterfaceMock struct { + t minimock.Tester + + funcLookupIPAddr func(ctx context.Context, host string) (ia1 []net.IPAddr, err error) + afterLookupIPAddrCounter uint64 + beforeLookupIPAddrCounter uint64 + LookupIPAddrMock mResolverInterfaceMockLookupIPAddr +} + +// NewResolverInterfaceMock returns a mock for ResolverInterface +func NewResolverInterfaceMock(t minimock.Tester) *ResolverInterfaceMock { + m := &ResolverInterfaceMock{t: t} + if controller, ok := t.(minimock.MockController); ok { + controller.RegisterMocker(m) + } + m.LookupIPAddrMock = mResolverInterfaceMockLookupIPAddr{mock: m} + + return m +} + +type mResolverInterfaceMockLookupIPAddr struct { + mock *ResolverInterfaceMock + defaultExpectation *ResolverInterfaceMockLookupIPAddrExpectation + expectations []*ResolverInterfaceMockLookupIPAddrExpectation +} + +// ResolverInterfaceMockLookupIPAddrExpectation specifies expectation struct of the ResolverInterface.LookupIPAddr +type ResolverInterfaceMockLookupIPAddrExpectation struct { + mock *ResolverInterfaceMock + params *ResolverInterfaceMockLookupIPAddrParams + results *ResolverInterfaceMockLookupIPAddrResults + Counter uint64 +} + +// ResolverInterfaceMockLookupIPAddrParams contains parameters of the ResolverInterface.LookupIPAddr +type ResolverInterfaceMockLookupIPAddrParams struct { + ctx context.Context + host string +} + +// ResolverInterfaceMockLookupIPAddrResults contains results of the ResolverInterface.LookupIPAddr +type ResolverInterfaceMockLookupIPAddrResults struct { + ia1 []net.IPAddr + err error +} + +// Expect sets up expected params for ResolverInterface.LookupIPAddr +func (m *mResolverInterfaceMockLookupIPAddr) Expect(ctx context.Context, host string) *mResolverInterfaceMockLookupIPAddr { + if m.mock.funcLookupIPAddr != nil { + m.mock.t.Fatalf("ResolverInterfaceMock.LookupIPAddr mock is already set by Set") + } + + if m.defaultExpectation == nil { + m.defaultExpectation = &ResolverInterfaceMockLookupIPAddrExpectation{} + } + + m.defaultExpectation.params = &ResolverInterfaceMockLookupIPAddrParams{ctx, host} + for _, e := range m.expectations { + if minimock.Equal(e.params, m.defaultExpectation.params) { + m.mock.t.Fatalf("Expectation set by When has same params: %#v", *m.defaultExpectation.params) + } + } + + return m +} + +// Return sets up results that will be returned by ResolverInterface.LookupIPAddr +func (m *mResolverInterfaceMockLookupIPAddr) Return(ia1 []net.IPAddr, err error) *ResolverInterfaceMock { + if m.mock.funcLookupIPAddr != nil { + m.mock.t.Fatalf("ResolverInterfaceMock.LookupIPAddr mock is already set by Set") + } + + if m.defaultExpectation == nil { + m.defaultExpectation = &ResolverInterfaceMockLookupIPAddrExpectation{mock: m.mock} + } + m.defaultExpectation.results = &ResolverInterfaceMockLookupIPAddrResults{ia1, err} + return m.mock +} + +//Set uses given function f to mock the ResolverInterface.LookupIPAddr method +func (m *mResolverInterfaceMockLookupIPAddr) Set(f func(ctx context.Context, host string) (ia1 []net.IPAddr, err error)) *ResolverInterfaceMock { + if m.defaultExpectation != nil { + m.mock.t.Fatalf("Default expectation is already set for the ResolverInterface.LookupIPAddr method") + } + + if len(m.expectations) > 0 { + m.mock.t.Fatalf("Some expectations are already set for the ResolverInterface.LookupIPAddr method") + } + + m.mock.funcLookupIPAddr = f + return m.mock +} + +// When sets expectation for the ResolverInterface.LookupIPAddr which will trigger the result defined by the following +// Then helper +func (m *mResolverInterfaceMockLookupIPAddr) When(ctx context.Context, host string) *ResolverInterfaceMockLookupIPAddrExpectation { + if m.mock.funcLookupIPAddr != nil { + m.mock.t.Fatalf("ResolverInterfaceMock.LookupIPAddr mock is already set by Set") + } + + expectation := &ResolverInterfaceMockLookupIPAddrExpectation{ + mock: m.mock, + params: &ResolverInterfaceMockLookupIPAddrParams{ctx, host}, + } + m.expectations = append(m.expectations, expectation) + return expectation +} + +// Then sets up ResolverInterface.LookupIPAddr return parameters for the expectation previously defined by the When method +func (e *ResolverInterfaceMockLookupIPAddrExpectation) Then(ia1 []net.IPAddr, err error) *ResolverInterfaceMock { + e.results = &ResolverInterfaceMockLookupIPAddrResults{ia1, err} + return e.mock +} + +// LookupIPAddr implements ResolverInterface +func (m *ResolverInterfaceMock) LookupIPAddr(ctx context.Context, host string) (ia1 []net.IPAddr, err error) { + atomic.AddUint64(&m.beforeLookupIPAddrCounter, 1) + defer atomic.AddUint64(&m.afterLookupIPAddrCounter, 1) + + for _, e := range m.LookupIPAddrMock.expectations { + if minimock.Equal(*e.params, ResolverInterfaceMockLookupIPAddrParams{ctx, host}) { + atomic.AddUint64(&e.Counter, 1) + return e.results.ia1, e.results.err + } + } + + if m.LookupIPAddrMock.defaultExpectation != nil { + atomic.AddUint64(&m.LookupIPAddrMock.defaultExpectation.Counter, 1) + want := m.LookupIPAddrMock.defaultExpectation.params + got := ResolverInterfaceMockLookupIPAddrParams{ctx, host} + if want != nil && !minimock.Equal(*want, got) { + m.t.Errorf("ResolverInterfaceMock.LookupIPAddr got unexpected parameters, want: %#v, got: %#v%s\n", *want, got, minimock.Diff(*want, got)) + } + + results := m.LookupIPAddrMock.defaultExpectation.results + if results == nil { + m.t.Fatal("No results are set for the ResolverInterfaceMock.LookupIPAddr") + } + return (*results).ia1, (*results).err + } + if m.funcLookupIPAddr != nil { + return m.funcLookupIPAddr(ctx, host) + } + m.t.Fatalf("Unexpected call to ResolverInterfaceMock.LookupIPAddr. %v %v", ctx, host) + return +} + +// LookupIPAddrAfterCounter returns a count of finished ResolverInterfaceMock.LookupIPAddr invocations +func (m *ResolverInterfaceMock) LookupIPAddrAfterCounter() uint64 { + return atomic.LoadUint64(&m.afterLookupIPAddrCounter) +} + +// LookupIPAddrBeforeCounter returns a count of ResolverInterfaceMock.LookupIPAddr invocations +func (m *ResolverInterfaceMock) LookupIPAddrBeforeCounter() uint64 { + return atomic.LoadUint64(&m.beforeLookupIPAddrCounter) +} + +// MinimockLookupIPAddrDone returns true if the count of the LookupIPAddr invocations corresponds +// the number of defined expectations +func (m *ResolverInterfaceMock) MinimockLookupIPAddrDone() bool { + for _, e := range m.LookupIPAddrMock.expectations { + if atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + // if default expectation was set then invocations count should be greater than zero + if m.LookupIPAddrMock.defaultExpectation != nil && atomic.LoadUint64(&m.afterLookupIPAddrCounter) < 1 { + return false + } + // if func was set then invocations count should be greater than zero + if m.funcLookupIPAddr != nil && atomic.LoadUint64(&m.afterLookupIPAddrCounter) < 1 { + return false + } + return true +} + +// MinimockLookupIPAddrInspect logs each unmet expectation +func (m *ResolverInterfaceMock) MinimockLookupIPAddrInspect() { + for _, e := range m.LookupIPAddrMock.expectations { + if atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to ResolverInterfaceMock.LookupIPAddr with params: %#v", *e.params) + } + } + + // if default expectation was set then invocations count should be greater than zero + if m.LookupIPAddrMock.defaultExpectation != nil && atomic.LoadUint64(&m.afterLookupIPAddrCounter) < 1 { + m.t.Errorf("Expected call to ResolverInterfaceMock.LookupIPAddr with params: %#v", *m.LookupIPAddrMock.defaultExpectation.params) + } + // if func was set then invocations count should be greater than zero + if m.funcLookupIPAddr != nil && atomic.LoadUint64(&m.afterLookupIPAddrCounter) < 1 { + m.t.Error("Expected call to ResolverInterfaceMock.LookupIPAddr") + } +} + +// MinimockFinish checks that all mocked methods have been called the expected number of times +func (m *ResolverInterfaceMock) MinimockFinish() { + if !m.minimockDone() { + m.MinimockLookupIPAddrInspect() + m.t.FailNow() + } +} + +// MinimockWait waits for all mocked methods to be called the expected number of times +func (m *ResolverInterfaceMock) MinimockWait(timeout time.Duration) { + timeoutCh := time.After(timeout) + for { + if m.minimockDone() { + return + } + select { + case <-timeoutCh: + m.MinimockFinish() + return + case <-time.After(10 * time.Millisecond): + } + } +} + +func (m *ResolverInterfaceMock) minimockDone() bool { + done := true + return done && + m.MinimockLookupIPAddrDone() +} diff --git a/internal/dns/resolver_test.go b/internal/dns/resolver_test.go new file mode 100644 index 00000000..b7f4f19a --- /dev/null +++ b/internal/dns/resolver_test.go @@ -0,0 +1,405 @@ +package dns + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/gojuno/minimock" + + "github.com/rekby/lets-proxy2/internal/th" + + "github.com/maxatome/go-testdeep" + mdns "github.com/miekg/dns" +) + +var ( + _ ResolverInterface = &Resolver{} +) + +func TestNewResolver(t *testing.T) { + td := testdeep.NewT(t) + resolver := NewResolver("1.2.3.4:53") + td.CmpDeeply(resolver.server, "1.2.3.4:53") + td.CmpDeeply(resolver.maxDNSRecursionDeep, 10) + td.NotNil(resolver.tcp) + td.NotNil(resolver.udp) + td.NotNil(resolver.lookupWithClient) +} + +func TestLookupWithClient(t *testing.T) { + ctx, cancel := th.TestContext() + defer cancel() + + td := testdeep.NewT(t) + mc := minimock.NewController(td) + + client := NewMDNSClientMock(mc) + client.ExchangeMock.Set(func(m *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) { + if m.Id == 0 { + td.Error("Unset msg id") + } + switch m.Question[0].Qtype { + case mdns.TypeA: + switch m.Question[0].Name { + case "alias.com": + td.CmpDeeply(m.Question, []mdns.Question{{Name: "alias.com", Qtype: mdns.TypeA, Qclass: mdns.ClassINET}}) + return &mdns.Msg{ + MsgHdr: mdns.MsgHdr{Id: m.Id}, + Answer: []mdns.RR{ + &mdns.CNAME{ + Hdr: mdns.RR_Header{ + Rrtype: mdns.TypeCNAME, + }, + Target: "alias-target.com", + }, + }, + }, 0, nil + case "alias-target.com": + td.CmpDeeply(m.Question, []mdns.Question{{Name: "alias-target.com", Qtype: mdns.TypeA, Qclass: mdns.ClassINET}}) + return &mdns.Msg{ + MsgHdr: mdns.MsgHdr{Id: m.Id}, + Answer: []mdns.RR{ + &mdns.A{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeA}, + A: net.IPv4(3, 4, 5, 6), + }, + }, + }, 0, nil + case "asd.com": + td.CmpDeeply(m.Question, []mdns.Question{{Name: "asd.com", Qtype: mdns.TypeA, Qclass: mdns.ClassINET}}) + return &mdns.Msg{ + MsgHdr: mdns.MsgHdr{Id: m.Id}, + Answer: []mdns.RR{ + &mdns.A{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeA}, + A: net.IPv4(1, 2, 3, 4), + }, + &mdns.A{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeA}, + A: net.IPv4(5, 6, 7, 8), + }, + }, + }, 0, nil + default: + td.Error("Unexpected domain") + return nil, 0, errors.New("unexpected domain") + } + case mdns.TypeAAAA: + td.CmpDeeply(m.Question, []mdns.Question{{Name: "asd.com", Qtype: mdns.TypeAAAA, Qclass: mdns.ClassINET}}) + return &mdns.Msg{ + MsgHdr: mdns.MsgHdr{Id: m.Id}, + Answer: []mdns.RR{ + &mdns.AAAA{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeAAAA}, + AAAA: net.ParseIP("::1a"), + }, + &mdns.AAAA{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeAAAA}, + AAAA: net.ParseIP("::1f"), + }, + }, + }, 0, nil + case mdns.TypeAVC: + return &mdns.Msg{ + MsgHdr: mdns.MsgHdr{Id: m.Id}, + Answer: []mdns.RR{ + &mdns.AVC{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeAVC}, + Txt: []string{"aaa"}, + }, + }, + }, 0, nil + + default: + td.Error("Unexpected record type", m.Question) + return nil, 0, errors.New("unexpected") + } + }) + + ips, err := lookupWithClient(ctx, "asd.com", "1.2.3.4:53", mdns.TypeA, 1, client) + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.IPv4(1, 2, 3, 4)}, {IP: net.IPv4(5, 6, 7, 8)}}) + + ips, err = lookupWithClient(ctx, "asd.com", "1.2.3.4:53", mdns.TypeAAAA, 1, client) + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.ParseIP("::1a")}, {IP: net.ParseIP("::1f")}}) + + ips, err = lookupWithClient(ctx, "alias.com", "1.2.3.4:53", mdns.TypeA, 2, client) + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.IPv4(3, 4, 5, 6)}}) + + ips, err = lookupWithClient(ctx, "asd.com", "1.2.3.4:53", mdns.TypeAVC, 1, client) + td.CmpNoError(err) + td.Nil(ips) + + ips, err = lookupWithClient(ctx, "asd.com", "1.2.3.4:53", mdns.TypeAVC, 0, client) + td.CmpError(err) + td.Nil(ips) + + client.ExchangeMock.Set(func(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) { + return &mdns.Msg{MsgHdr: mdns.MsgHdr{Id: msg.Id, Truncated: true}}, 0, nil + }) + ips, err = lookupWithClient(ctx, "asd.com", "1.2.3.4:53", mdns.TypeA, 1, client) + td.CmpDeeply(err, errTruncatedResponse) + td.Nil(ips) + + client.ExchangeMock.Set(func(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) { + return &mdns.Msg{MsgHdr: mdns.MsgHdr{Id: msg.Id, Truncated: true}}, 0, errors.New("asd") + }) + ips, err = lookupWithClient(ctx, "asd.com", "1.2.3.4:53", mdns.TypeA, 1, client) + td.CmpDeeply(err, errTruncatedResponse) + td.Nil(ips) + + client.ExchangeMock.Set(func(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) { + return &mdns.Msg{MsgHdr: mdns.MsgHdr{Id: msg.Id}}, 0, errors.New("asd") + }) + ips, err = lookupWithClient(ctx, "asd.com", "1.2.3.4:53", mdns.TypeA, 1, client) + td.CmpDeeply(err, errors.New("asd")) + td.Nil(ips) + + client.ExchangeMock.Set(func(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) { + return &mdns.Msg{ + MsgHdr: mdns.MsgHdr{Id: msg.Id}, + Answer: []mdns.RR{ + &mdns.AAAA{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeAAAA}, + AAAA: net.ParseIP("::1a"), + }, + &mdns.AAAA{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeAAAA}, + AAAA: net.ParseIP("::1f"), + }, + &mdns.A{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeA}, + A: net.IPv4(1, 2, 3, 4), + }, + &mdns.A{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeA}, + A: net.IPv4(5, 6, 7, 8), + }, + }, + }, 0, nil + }) + ips, err = lookupWithClient(ctx, "asd.com", "1.2.3.4:53", mdns.TypeA, 1, client) + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.IPv4(1, 2, 3, 4)}, {IP: net.IPv4(5, 6, 7, 8)}}) + + client.ExchangeMock.Set(func(msg *mdns.Msg, address string) (r *mdns.Msg, rtt time.Duration, err error) { + time.Sleep(time.Second) + return &mdns.Msg{ + MsgHdr: mdns.MsgHdr{Id: msg.Id}, + Answer: []mdns.RR{ + &mdns.A{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeA}, + A: net.IPv4(1, 2, 3, 4), + }, + &mdns.A{ + Hdr: mdns.RR_Header{Rrtype: mdns.TypeA}, + A: net.IPv4(5, 6, 7, 8), + }, + }, + }, 0, nil + }) + timeoutCtx, timeoutCancelCtx := context.WithTimeout(ctx, time.Millisecond*10) + defer timeoutCancelCtx() + ips, err = lookupWithClient(timeoutCtx, "asd.com", "1.2.3.4:53", mdns.TypeA, 1, client) + td.CmpError(err) + td.Nil(ips) + + timeoutCancelled, timeoutCancelledCancelCtx := context.WithCancel(ctx) + timeoutCancelledCancelCtx() + ips, err = lookupWithClient(timeoutCancelled, "asd.com", "1.2.3.4:53", mdns.TypeA, 1, client) + td.CmpError(err) + td.Nil(ips) +} + +func TestResolver_Lookup(t *testing.T) { + ctx, cancel := th.TestContext() + defer cancel() + + td := testdeep.NewT(t) + mc := minimock.NewController(td) + + clientUDP := NewMDNSClientMock(mc) + clientTCP := NewMDNSClientMock(mc) + + resolver := &Resolver{ + udp: clientUDP, + tcp: clientTCP, + server: "dns", + maxDNSRecursionDeep: 13, + } + + resolver.lookupWithClient = func(ctx context.Context, host string, server string, recordType uint16, recursion int, client mDNSClient) (addrs []net.IPAddr, e error) { + td.CmpDeeply(recursion, 13) + answer, _, err := client.Exchange(&mdns.Msg{Question: []mdns.Question{ + {Name: host, Qtype: recordType}, + }}, server) + if err != nil { + return nil, err + } + return []net.IPAddr{{IP: answer.Answer[0].(*mdns.A).A}}, nil + } + + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "1", Qtype: mdns.TypeA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.A{A: net.IPv4(1, 2, 3, 4)}}}, 0, nil) + ips, err := resolver.lookup(ctx, "1", mdns.TypeA) + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.IPv4(1, 2, 3, 4)}}) + + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "2", Qtype: mdns.TypeA}, + }}, "dns"). + Then(nil, 0, errTruncatedResponse) + clientTCP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "2", Qtype: mdns.TypeA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.A{A: net.IPv4(1, 2, 3, 4)}}}, 0, nil) + ips, err = resolver.lookup(ctx, "2", mdns.TypeA) + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.IPv4(1, 2, 3, 4)}}) + + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "3", Qtype: mdns.TypeA}, + }}, "dns"). + Then(nil, 0, errors.New("test3")) + ips, err = resolver.lookup(ctx, "3", mdns.TypeA) + td.CmpDeeply(err, errors.New("test3")) + td.Nil(ips) + + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "4", Qtype: mdns.TypeA}, + }}, "dns"). + Then(nil, 0, errTruncatedResponse) + clientTCP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "4", Qtype: mdns.TypeA}, + }}, "dns"). + Then(nil, 0, errors.New("test4")) + ips, err = resolver.lookup(ctx, "4", mdns.TypeA) + td.CmpDeeply(err, errors.New("test4")) + td.Nil(ips) +} + +func TestResolver_LookupIPAddr(t *testing.T) { + ctx, cancel := th.TestContext() + defer cancel() + + td := testdeep.NewT(t) + mc := minimock.NewController(td) + + clientUDP := NewMDNSClientMock(mc) + clientTCP := NewMDNSClientMock(mc) + + resolver := &Resolver{ + udp: clientUDP, + tcp: clientTCP, + server: "dns", + maxDNSRecursionDeep: 13, + } + + resolver.lookupWithClient = func(ctx context.Context, host string, server string, recordType uint16, recursion int, client mDNSClient) (addrs []net.IPAddr, e error) { + td.CmpDeeply(recursion, 13) + dnsAnswer, _, err := client.Exchange(&mdns.Msg{Question: []mdns.Question{ + {Name: host, Qtype: recordType}, + }}, server) + if err != nil { + return nil, err + } + var resIPs []net.IPAddr + for _, r := range dnsAnswer.Answer { + switch recordType { + case mdns.TypeA: + resIPs = append(resIPs, net.IPAddr{IP: r.(*mdns.A).A}) + case mdns.TypeAAAA: + resIPs = append(resIPs, net.IPAddr{IP: r.(*mdns.AAAA).AAAA}) + default: + // pass + } + } + return resIPs, nil + } + + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "1.", Qtype: mdns.TypeA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.A{A: net.IPv4(1, 2, 3, 4)}}}, 0, nil) + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "1.", Qtype: mdns.TypeAAAA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.AAAA{AAAA: net.ParseIP("::bb")}}}, 0, nil) + ips, err := resolver.LookupIPAddr(ctx, "1") + td.CmpNoError(err) + td.CmpDeeply(ips, []net.IPAddr{{IP: net.IPv4(1, 2, 3, 4)}, {IP: net.ParseIP("::bb")}}) + + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "2.", Qtype: mdns.TypeA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.A{A: net.IPv4(1, 2, 3, 4)}}}, 0, errors.New("err2-1")) + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "2.", Qtype: mdns.TypeAAAA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.AAAA{AAAA: net.ParseIP("::bb")}}}, 0, nil) + ips, err = resolver.LookupIPAddr(ctx, "2") + td.CmpDeeply(err, errors.New("err2-1")) + td.Nil(ips) + + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "3.", Qtype: mdns.TypeA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.A{A: net.IPv4(1, 2, 3, 4)}}}, 0, errors.New("err3-1")) + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "3.", Qtype: mdns.TypeAAAA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.AAAA{AAAA: net.ParseIP("::bb")}}}, 0, errors.New("err3-2")) + ips, err = resolver.LookupIPAddr(ctx, "3") + td.CmpDeeply(err, errors.New("err3-1")) + td.Nil(ips) + + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "4.", Qtype: mdns.TypeA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.A{A: net.IPv4(1, 2, 3, 4)}}}, 0, nil) + clientUDP.ExchangeMock.When(&mdns.Msg{Question: []mdns.Question{ + {Name: "4.", Qtype: mdns.TypeAAAA}, + }}, "dns"). + Then(&mdns.Msg{Answer: []mdns.RR{&mdns.AAAA{AAAA: net.ParseIP("::bb")}}}, 0, errors.New("err4-2")) + ips, err = resolver.LookupIPAddr(ctx, "4") + td.CmpDeeply(err, errors.New("err4-2")) + td.Nil(ips) +} + +func TestResolverReal(t *testing.T) { + ctx, cancel := th.TestContext() + defer cancel() + + td := testdeep.NewT(t) + + var ips []net.IPAddr + var err error + + r := NewResolver("8.8.8.8:53") + ips, err = r.LookupIPAddr(ctx, "one.one.one.one") + td.CmpNoError(err) + td.Contains(ips, + testdeep.Any( + net.IPAddr{IP: net.IPv4(1, 1, 1, 1).To4()}, + net.IPAddr{IP: net.IPv4(1, 1, 1, 1).To16()}, + ), + ) + + r = NewResolver("1.1.1.1:53") + ips, err = r.LookupIPAddr(ctx, "qwe.l.rekby.ru") + td.CmpNoError(err) + td.Contains(ips, + testdeep.Any( + net.IPAddr{IP: net.ParseIP("127.0.0.1").To4()}, + net.IPAddr{IP: net.ParseIP("127.0.0.1").To16()}, + ), + ) + +} diff --git a/internal/domain_checker/config.go b/internal/domain_checker/config.go index 5eac3cc0..f3fcdf7e 100644 --- a/internal/domain_checker/config.go +++ b/internal/domain_checker/config.go @@ -5,6 +5,11 @@ import ( "context" "net" "regexp" + "strings" + + "github.com/pkg/errors" + + "github.com/rekby/lets-proxy2/internal/dns" zc "github.com/rekby/zapcontext" @@ -13,10 +18,11 @@ import ( ) type Config struct { - IPSelf bool `default:"true" comment:"Allow domain if it resolver for one of public IPs of this server."` - IPWhiteList string `default:"" comment:"Allow domain if it resolver for one of the ips."` - BlackList string `default:"" comment:"Regexp in golang syntax of blacklisted domain for issue certificate.\nThis list overrided by whitelist."` - WhiteList string `default:"" comment:"Regexp in golang syntax of whitelist domains for issue certificate.\nWhitelist need for allow part of domains, which excluded by blacklist.\n"` + IPSelf bool + IPWhiteList string + BlackList string + WhiteList string + Resolver string } func (c *Config) CreateDomainChecker(ctx context.Context) (DomainChecker, error) { @@ -42,6 +48,36 @@ func (c *Config) CreateDomainChecker(ctx context.Context) (DomainChecker, error) listCheckers = NewAny(listCheckers, NewRegexp(r)) } + var resolver Resolver + if strings.TrimSpace(c.Resolver) == "" { + resolver = net.DefaultResolver + } else { + stringAddresses := strings.Split(c.Resolver, ",") + var resolvers = make([]dns.ResolverInterface, 0, len(stringAddresses)) + for _, addr := range stringAddresses { + addr = strings.TrimSpace(addr) + if addr == "" { + continue + } + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + logger.Error("Can't resolve dns server address string", zap.String("addr", addr), zap.Error(err)) + return nil, err + } + if len(tcpAddr.IP) == 0 { + logger.Error("Can't resolve dns server address ip - it is empty.", zap.String("addr", addr)) + return nil, errors.New("empty ip address") + } + if tcpAddr.Port == 0 { + tcpAddr.Port = 53 // default dns port + } + tcpAddrString := tcpAddr.String() + resolvers = append(resolvers, dns.NewResolver(tcpAddrString)) + } + resolver = dns.NewParallel(resolvers...) + } + SetDefaultResolver(resolver) + var ipCheckers Any if c.IPSelf { diff --git a/internal/domain_checker/ip_list.go b/internal/domain_checker/ip_list.go index 8d7a2331..71f98633 100644 --- a/internal/domain_checker/ip_list.go +++ b/internal/domain_checker/ip_list.go @@ -50,8 +50,14 @@ var ( mustParseNet("fe80::/10"), mustParseNet("ff00::/8"), } + + defaultResolver Resolver = net.DefaultResolver ) +func SetDefaultResolver(resolver Resolver) { + defaultResolver = resolver +} + type IPList struct { Addresses AllowedIPAddresses Resolver Resolver @@ -75,7 +81,7 @@ func NewIPList(ctx context.Context, addresses AllowedIPAddresses) *IPList { res := &IPList{ ctx: ctx, Addresses: addresses, - Resolver: net.DefaultResolver, + Resolver: defaultResolver, AutoUpdateInterval: time.Hour, } res.updateIPs() @@ -94,6 +100,11 @@ func (s *IPList) IsDomainAllowed(ctx context.Context, domain string) (bool, erro return false, err } + if len(ips) == 0 { + logger.Info("Doesn't allow domain without ip address") + return false, errors.New("domain has no ip address") + } + s.mu.RLock() defer s.mu.RUnlock() diff --git a/internal/domain_checker/ip_list_test.go b/internal/domain_checker/ip_list_test.go index 72a963d8..7f086d5f 100644 --- a/internal/domain_checker/ip_list_test.go +++ b/internal/domain_checker/ip_list_test.go @@ -185,6 +185,17 @@ func TestIPList_UpdateByTimer(t *testing.T) { time.Sleep(50 * time.Millisecond) } +func TestSetDefaultResolver(t *testing.T) { + oldResolver := defaultResolver + defer func() { + defaultResolver = oldResolver + }() + + resolver := NewResolverMock(t) + SetDefaultResolver(resolver) + testdeep.CmpDeeply(t, defaultResolver, resolver) +} + func TestSelfPublicIP_IsDomainAllowed(t *testing.T) { var _ DomainChecker = &IPList{} diff --git a/vendor/modules.txt b/vendor/modules.txt index b8be50bb..6c03cf08 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -24,6 +24,8 @@ github.com/maxatome/go-testdeep/internal/dark github.com/maxatome/go-testdeep/internal/location github.com/maxatome/go-testdeep/internal/types github.com/maxatome/go-testdeep/internal/util +# github.com/miekg/dns v1.1.12 +github.com/miekg/dns # github.com/pkg/errors v0.8.1 github.com/pkg/errors # github.com/pmezard/go-difflib v1.0.0 @@ -49,10 +51,18 @@ go.uber.org/zap/internal/color go.uber.org/zap/internal/exit # golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 golang.org/x/crypto/acme +golang.org/x/crypto/ed25519 +golang.org/x/crypto/ed25519/internal/edwards25519 # golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6 golang.org/x/net/idna +golang.org/x/net/ipv4 +golang.org/x/net/ipv6 +golang.org/x/net/bpf +golang.org/x/net/internal/iana +golang.org/x/net/internal/socket # golang.org/x/sys v0.0.0-20190428183149-804c0c7841b5 golang.org/x/sys/windows/svc +golang.org/x/sys/unix golang.org/x/sys/windows # golang.org/x/text v0.3.2 golang.org/x/text/secure/bidirule