File dnsproxy-0.75.0.obscpio of Package dnsproxy
07070100000000000081A4000000000000000000000001679A649F00000080000000000000000000000000000000000000001D00000000dnsproxy-0.75.0/.codecov.ymlcoverage:
status:
project:
default:
target: 40%
threshold: null
patch: false
changes: false
07070100000001000081A4000000000000000000000001679A649F00000049000000000000000000000000000000000000001E00000000dnsproxy-0.75.0/.dockerignore# Ignore everything except for explicitly allowed stuff.
*
!build/docker
07070100000002000081A4000000000000000000000001679A649F00000011000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/.gitattributesvendor/** binary
07070100000003000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001800000000dnsproxy-0.75.0/.github07070100000004000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002200000000dnsproxy-0.75.0/.github/workflows07070100000005000081A4000000000000000000000001679A649F00000A62000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/.github/workflows/build.yamlname: Build
'env':
'GO_VERSION': '1.23.5'
'on':
'push':
'tags':
- 'v*'
'branches':
- '*'
'pull_request':
jobs:
tests:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os:
- windows-latest
- macos-latest
- ubuntu-latest
steps:
- uses: actions/checkout@master
- uses: actions/setup-go@v2
with:
go-version: '${{ env.GO_VERSION }}'
- name: Run tests
env:
CI: "1"
run: |-
make test
- name: Upload coverage
uses: codecov/codecov-action@v1
if: "success() && matrix.os == 'ubuntu-latest'"
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./coverage.txt
build:
needs:
- tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- uses: actions/setup-go@v2
with:
go-version: '${{ env.GO_VERSION }}'
- name: Build release
run: |-
set -e -u -x
RELEASE_VERSION="${GITHUB_REF##*/}"
if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi
echo "RELEASE_VERSION=\"${RELEASE_VERSION}\"" >> $GITHUB_ENV
make VERBOSE=1 VERSION="${RELEASE_VERSION}" release
ls -l build/dnsproxy-*
- name: Create release
if: startsWith(github.ref, 'refs/tags/v')
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.ref }}
release_name: Release ${{ github.ref }}
draft: false
prerelease: false
- name: Upload
if: startsWith(github.ref, 'refs/tags/v')
uses: xresloader/upload-to-github-release@v1.3.12
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
file: "build/dnsproxy-*.tar.gz;build/dnsproxy-*.zip"
tags: true
draft: false
notify:
needs:
- build
if:
${{ always() &&
(
github.event_name == 'push' ||
github.event.pull_request.head.repo.full_name == github.repository
)
}}
runs-on: ubuntu-latest
steps:
- name: Conclusion
uses: technote-space/workflow-conclusion-action@v1
- name: Send Slack notif
uses: 8398a7/action-slack@v3
with:
status: ${{ env.WORKFLOW_CONCLUSION }}
fields: workflow, repo, message, commit, author, eventName,ref
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
07070100000006000081A4000000000000000000000001679A649F0000090E000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/.github/workflows/docker.yml'name': Docker
'env':
'GO_VERSION': '1.23.5'
'on':
'push':
'tags':
- 'v*'
# Builds from the master branch will be pushed with the `dev` tag.
'branches':
- 'master'
'jobs':
'docker':
'runs-on': 'ubuntu-latest'
'steps':
- 'name': 'Checkout'
'uses': 'actions/checkout@v3'
'with':
'fetch-depth': 0
- 'name': 'Set up Go'
'uses': 'actions/setup-go@v3'
'with':
'go-version': '${{ env.GO_VERSION }}'
- 'name': 'Set up Go modules cache'
'uses': 'actions/cache@v2'
'with':
'path': '~/go/pkg/mod'
'key': "${{ runner.os }}-go-${{ hashFiles('go.sum') }}"
'restore-keys': '${{ runner.os }}-go-'
- 'name': 'Set up QEMU'
'uses': 'docker/setup-qemu-action@v1'
- 'name': 'Set up Docker Buildx'
'uses': 'docker/setup-buildx-action@v1'
- 'name': 'Publish to Docker Hub'
'env':
'DOCKER_USER': ${{ secrets.DOCKER_USER }}
'DOCKER_PASSWORD': ${{ secrets.DOCKER_PASSWORD }}
'run': |-
set -e -u -x
RELEASE_VERSION="${GITHUB_REF##*/}"
if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi
echo "RELEASE_VERSION=\"${RELEASE_VERSION}\"" >> $GITHUB_ENV
docker login \
-u="${DOCKER_USER}" \
-p="${DOCKER_PASSWORD}"
make \
VERSION="${RELEASE_VERSION}" \
DOCKER_IMAGE_NAME="adguard/dnsproxy" \
DOCKER_OUTPUT="type=image,name=adguard/dnsproxy,push=true" \
VERBOSE="1" \
docker
'notify':
'needs':
- 'docker'
'if':
${{ always() &&
(
github.event_name == 'push' ||
github.event.pull_request.head.repo.full_name == github.repository
)
}}
'runs-on': ubuntu-latest
'steps':
- 'name': Conclusion
'uses': technote-space/workflow-conclusion-action@v1
- 'name': Send Slack notif
'uses': 8398a7/action-slack@v3
'with':
'status': ${{ env.WORKFLOW_CONCLUSION }}
'fields': workflow, repo, message, commit, author, eventName,ref
'env':
'GITHUB_TOKEN': ${{ secrets.GITHUB_TOKEN }}
'SLACK_WEBHOOK_URL': ${{ secrets.SLACK_WEBHOOK_URL }}
07070100000007000081A4000000000000000000000001679A649F00000593000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/.github/workflows/lint.yaml'name': 'lint'
'env':
'GO_VERSION': '1.23.5'
'on':
'push':
'tags':
- 'v*'
'branches':
- '*'
'pull_request':
'jobs':
'go-lint':
'runs-on': 'ubuntu-latest'
'steps':
- 'uses': 'actions/checkout@v2'
- 'name': 'Set up Go'
'uses': 'actions/setup-go@v3'
'with':
'go-version': '${{ env.GO_VERSION }}'
- 'name': 'run-lint'
'run': >
make go-deps go-tools go-lint
'notify':
'needs':
- 'go-lint'
# Secrets are not passed to workflows that are triggered by a pull request
# from a fork.
#
# Use always() to signal to the runner that this job must run even if the
# previous ones failed.
'if':
${{
always() &&
github.repository_owner == 'AdguardTeam' &&
(
github.event_name == 'push' ||
github.event.pull_request.head.repo.full_name == github.repository
)
}}
'runs-on': 'ubuntu-latest'
'steps':
- 'name': 'Conclusion'
'uses': 'technote-space/workflow-conclusion-action@v1'
- 'name': 'Send Slack notif'
'uses': '8398a7/action-slack@v3'
'with':
'status': '${{ env.WORKFLOW_CONCLUSION }}'
'fields': 'workflow, repo, message, commit, author, eventName, ref'
'env':
'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}'
'SLACK_WEBHOOK_URL': '${{ secrets.SLACK_WEBHOOK_URL }}'
07070100000008000081A4000000000000000000000001679A649F000001D3000000000000000000000000000000000000001B00000000dnsproxy-0.75.0/.gitignore# Please, DO NOT put your text editors' temporary files here. The more are
# added, the harder it gets to maintain and manage projects' gitignores. Put
# them into your global gitignore file instead.
#
# See https://stackoverflow.com/a/7335487/1892060.
#
# Only build, run, and test outputs here. Sorted. With negations at the
# bottom to make sure they take effect.
*.out
*.test
/bin/
build
dnsproxy
dnsproxy.exe
example.crt
example.key
coverage.txt
config.yaml
07070100000009000081A4000000000000000000000001679A649F00002C57000000000000000000000000000000000000001800000000dnsproxy-0.75.0/LICENSE
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 Adguard Software Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
0707010000000A000081A4000000000000000000000001679A649F00000BDC000000000000000000000000000000000000001900000000dnsproxy-0.75.0/Makefile# Keep the Makefile POSIX-compliant. We currently allow hyphens in
# target names, but that may change in the future.
#
# See https://pubs.opengroup.org/onlinepubs/9799919799/utilities/make.html.
.POSIX:
# This comment is used to simplify checking local copies of the
# Makefile. Bump this number every time a significant change is made to
# this Makefile.
#
# AdGuard-Project-Version: 9
# Don't name these macros "GO" etc., because GNU Make apparently makes
# them exported environment variables with the literal value of
# "${GO:-go}" and so on, which is not what we need. Use a dot in the
# name to make sure that users don't have an environment variable with
# the same name.
#
# See https://unix.stackexchange.com/q/646255/105635.
GO.MACRO = $${GO:-go}
VERBOSE.MACRO = $${VERBOSE:-0}
BRANCH = $${BRANCH:-$$(git rev-parse --abbrev-ref HEAD)}
DIST_DIR = build
GOAMD64 = v1
GOPROXY = https://proxy.golang.org|direct
GOTOOLCHAIN = go1.23.5
GOTELEMETRY = off
OUT = dnsproxy
RACE = 0
REVISION = $${REVISION:-$$(git rev-parse --short HEAD)}
VERSION = 0
ENV = env\
BRANCH="$(BRANCH)"\
DIST_DIR='$(DIST_DIR)'\
GO="$(GO.MACRO)"\
GOAMD64='$(GOAMD64)'\
GOPROXY='$(GOPROXY)'\
GOTELEMETRY='$(GOTELEMETRY)'\
GOTOOLCHAIN='$(GOTOOLCHAIN)'\
OUT='$(OUT)'\
PATH="$${PWD}/bin:$$("$(GO.MACRO)" env GOPATH)/bin:$${PATH}"\
RACE='$(RACE)'\
REVISION="$(REVISION)"\
VERBOSE="$(VERBOSE.MACRO)"\
VERSION="$(VERSION)"\
# Keep the line above blank.
ENV_MISC = env\
PATH="$${PWD}/bin:$$("$(GO.MACRO)" env GOPATH)/bin:$${PATH}"\
VERBOSE="$(VERBOSE.MACRO)"\
# Keep the line above blank.
# Keep this target first, so that a naked make invocation triggers a
# full build.
build: go-deps go-build
init: ; git config core.hooksPath ./scripts/hooks
test: go-test
go-build: ; $(ENV) "$(SHELL)" ./scripts/make/go-build.sh
go-deps: ; $(ENV) "$(SHELL)" ./scripts/make/go-deps.sh
go-env: ; $(ENV) "$(GO.MACRO)" env
go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh
go-test: ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh
go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh
go-upd-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-upd-tools.sh
go-check: go-tools go-lint go-test
# A quick check to make sure that all operating systems relevant to the
# development of the project can be typechecked and built successfully.
go-os-check:
$(ENV) GOOS='darwin' "$(GO.MACRO)" vet ./...
$(ENV) GOOS='freebsd' "$(GO.MACRO)" vet ./...
$(ENV) GOOS='openbsd' "$(GO.MACRO)" vet ./...
$(ENV) GOOS='linux' "$(GO.MACRO)" vet ./...
$(ENV) GOOS='windows' "$(GO.MACRO)" vet ./...
txt-lint: ; $(ENV) "$(SHELL)" ./scripts/make/txt-lint.sh
md-lint: ; $(ENV_MISC) "$(SHELL)" ./scripts/make/md-lint.sh
sh-lint: ; $(ENV_MISC) "$(SHELL)" ./scripts/make/sh-lint.sh
clean: ; $(ENV) $(GO.MACRO) clean && rm -f -r '$(DIST_DIR)'
release: clean
$(ENV) "$(SHELL)" ./scripts/make/build-release.sh
docker: release
$(ENV) "$(SHELL)" ./scripts/make/build-docker.sh
0707010000000B000081A4000000000000000000000001679A649F00004ABE000000000000000000000000000000000000001A00000000dnsproxy-0.75.0/README.md[](https://codecov.io/github/AdguardTeam/dnsproxy?branch=master)
[](https://goreportcard.com/report/AdguardTeam/dnsproxy)
[](https://godoc.org/github.com/AdguardTeam/dnsproxy)
# DNS Proxy <!-- omit in toc -->
A simple DNS proxy server that supports all existing DNS protocols including
`DNS-over-TLS`, `DNS-over-HTTPS`, `DNSCrypt`, and `DNS-over-QUIC`. Moreover,
it can work as a `DNS-over-HTTPS`, `DNS-over-TLS` or `DNS-over-QUIC` server.
- [How to install](#how-to-install)
- [How to build](#how-to-build)
- [Usage](#usage)
- [Examples](#examples)
- [Simple options](#simple-options)
- [Encrypted upstreams](#encrypted-upstreams)
- [Encrypted DNS server](#encrypted-dns-server)
- [Additional features](#additional-features)
- [DNS64 server](#dns64-server)
- [Fastest addr + cache-min-ttl](#fastest-addr--cache-min-ttl)
- [Specifying upstreams for domains](#specifying-upstreams-for-domains)
- [EDNS Client Subnet](#edns-client-subnet)
- [Bogus NXDomain](#bogus-nxdomain)
## How to install
There are several ways to install `dnsproxy`:
1. Grab the binary for your device/OS from the [Releases][releases] page.
2. Use the [official Docker image][docker].
3. Build it yourself (see the instructions below).
[releases]: https://github.com/AdguardTeam/dnsproxy/releases
[docker]: https://hub.docker.com/r/adguard/dnsproxy
## How to build
You will need Go v1.21 or later.
```shell
$ make build
```
## Usage
```
Usage:
dnsproxy [OPTIONS]
Application Options:
--config-path= yaml configuration file. Minimal working configuration in config.yaml.dist. Options passed through command
line will override the ones from this file.
-o, --output= Path to the log file. If not set, write to stdout.
-c, --tls-crt= Path to a file with the certificate chain
-k, --tls-key= Path to a file with the private key
--https-server-name= Set the Server header for the responses from the HTTPS server. (default: dnsproxy)
--https-userinfo= If set, all DoH queries are required to have this basic authentication information.
-g, --dnscrypt-config= Path to a file with DNSCrypt configuration. You can generate one using https://github.com/ameshkov/dnscrypt
--edns-addr= Send EDNS Client Address
--upstream-mode= Defines the upstreams logic mode, possible values: load_balance, parallel, fastest_addr (default:
load_balance)
-l, --listen= Listening addresses
-p, --port= Listening ports. Zero value disables TCP and UDP listeners
-s, --https-port= Listening ports for DNS-over-HTTPS
-t, --tls-port= Listening ports for DNS-over-TLS
-q, --quic-port= Listening ports for DNS-over-QUIC
-y, --dnscrypt-port= Listening ports for DNSCrypt
-u, --upstream= An upstream to be used (can be specified multiple times). You can also specify path to a file with the
list of servers
-b, --bootstrap= Bootstrap DNS for DoH and DoT, can be specified multiple times (default: use system-provided)
-f, --fallback= Fallback resolvers to use when regular ones are unavailable, can be specified multiple times. You can also
specify path to a file with the list of servers
--private-rdns-upstream= Private DNS upstreams to use for reverse DNS lookups of private addresses, can be specified multiple times
--dns64-prefix= Prefix used to handle DNS64. If not specified, dnsproxy uses the 'Well-Known Prefix' 64:ff9b::. Can be
specified multiple times
--private-subnets= Private subnets to use for reverse DNS lookups of private addresses
--bogus-nxdomain= Transform the responses containing at least a single IP that matches specified addresses and CIDRs into
NXDOMAIN. Can be specified multiple times.
--hosts-files= List of paths to the hosts files, can be specified multiple times
--timeout= Timeout for outbound DNS queries to remote upstream servers in a human-readable form (default: 10s)
--cache-min-ttl= Minimum TTL value for DNS entries, in seconds. Capped at 3600. Artificially extending TTLs should only be
done with careful consideration.
--cache-max-ttl= Maximum TTL value for DNS entries, in seconds.
--cache-size= Cache size (in bytes). Default: 64k
-r, --ratelimit= Ratelimit (requests per second)
--ratelimit-subnet-len-ipv4= Ratelimit subnet length for IPv4. (default: 24)
--ratelimit-subnet-len-ipv6= Ratelimit subnet length for IPv6. (default: 56)
--udp-buf-size= Set the size of the UDP buffer in bytes. A value <= 0 will use the system default.
--max-go-routines= Set the maximum number of go routines. A zero value will not not set a maximum.
--tls-min-version= Minimum TLS version, for example 1.0
--tls-max-version= Maximum TLS version, for example 1.3
--pprof If present, exposes pprof information on localhost:6060.
--version Prints the program version
-v, --verbose Verbose output (optional)
--insecure Disable secure TLS certificate validation
--ipv6-disabled If specified, all AAAA requests will be replied with NoError RCode and empty answer
--http3 Enable HTTP/3 support
--cache-optimistic If specified, optimistic DNS cache is enabled
--cache If specified, DNS cache is enabled
--refuse-any If specified, refuse ANY requests
--edns Use EDNS Client Subnet extension
--dns64 If specified, dnsproxy will act as a DNS64 server
--use-private-rdns If specified, use private upstreams for reverse DNS lookups of private addresses
--hosts-file-enabled= If specified, use hosts files for resolving (default: true)
Help Options:
-h, --help Show this help message
```
## Examples
### Simple options
Runs a DNS proxy on `0.0.0.0:53` with a single upstream - Google DNS.
```shell
./dnsproxy -u 8.8.8.8:53
```
The same proxy with verbose logging enabled, writing the log to the file `log.txt`.
```shell
./dnsproxy -u 8.8.8.8:53 -v -o log.txt
```
Runs a DNS proxy on `127.0.0.1:5353` with multiple upstreams.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8:53 -u 1.1.1.1:53
```
Listen on multiple interfaces and ports:
```shell
./dnsproxy -l 127.0.0.1 -l 192.168.1.10 -p 5353 -p 5354 -u 1.1.1.1
```
The plain DNS upstream server may be specified in several ways:
- With a plain IP address:
```shell
./dnsproxy -l 127.0.0.1 -u 8.8.8.8:53
```
- With a hostname or plain IP address and the `udp://` scheme:
```shell
./dnsproxy -l 127.0.0.1 -u udp://dns.google -u udp://1.1.1.1
```
- With a hostname or plain IP address and the `tcp://` scheme to force using
TCP:
```shell
./dnsproxy -l 127.0.0.1 -u tcp://dns.google -u tcp://1.1.1.1
```
### Encrypted upstreams
DNS-over-TLS upstream:
```shell
./dnsproxy -u tls://dns.adguard.com
```
DNS-over-HTTPS upstream with specified bootstrap DNS:
```shell
./dnsproxy -u https://dns.adguard.com/dns-query -b 1.1.1.1:53
```
DNS-over-QUIC upstream:
```shell
./dnsproxy -u quic://dns.adguard.com
```
DNS-over-HTTPS upstream with enabled HTTP/3 support (chooses it if it's faster):
```shell
./dnsproxy -u https://dns.google/dns-query --http3
```
DNS-over-HTTPS upstream with forced HTTP/3 (no fallback to other protocol):
```shell
./dnsproxy -u h3://dns.google/dns-query
```
DNSCrypt upstream ([DNS Stamp](https://dnscrypt.info/stamps) of AdGuard DNS):
```shell
./dnsproxy -u sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20
```
DNS-over-HTTPS upstream ([DNS Stamp](https://dnscrypt.info/stamps) of Cloudflare DNS):
```shell
./dnsproxy -u sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk
```
DNS-over-TLS upstream with two fallback servers (to be used when the main upstream is not available):
```shell
./dnsproxy -u tls://dns.adguard.com -f 8.8.8.8:53 -f 1.1.1.1:53
```
### Encrypted DNS server
Runs a DNS-over-TLS proxy on `127.0.0.1:853`.
```shell
./dnsproxy -l 127.0.0.1 --tls-port=853 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```
Runs a DNS-over-HTTPS proxy on `127.0.0.1:443`.
```shell
./dnsproxy -l 127.0.0.1 --https-port=443 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```
Runs a DNS-over-HTTPS proxy on `127.0.0.1:443` with HTTP/3 support.
```shell
./dnsproxy -l 127.0.0.1 --https-port=443 --http3 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```
Runs a DNS-over-QUIC proxy on `127.0.0.1:853`.
```shell
./dnsproxy -l 127.0.0.1 --quic-port=853 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```
Runs a DNSCrypt proxy on `127.0.0.1:443`.
```shell
./dnsproxy -l 127.0.0.1 --dnscrypt-config=./dnscrypt-config.yaml --dnscrypt-port=443 --upstream=8.8.8.8:53 -p 0
```
> Please note that in order to run a DNSCrypt proxy, you need to obtain a DNSCrypt configuration first. You can use the https://github.com/ameshkov/dnscrypt command-line tool to do that with a command like this: `./dnscrypt generate --provider-name=2.dnscrypt-cert.example.org --out=dnscrypt-config.yaml`
### Additional features
Runs a DNS proxy on `0.0.0.0:53` with the rate limit set to `10 rps`, the DNS cache enabled, and type=ANY requests refused.
```shell
./dnsproxy -u 8.8.8.8:53 -r 10 --cache --refuse-any
```
Runs a DNS proxy on `127.0.0.1:5353` with multiple upstreams and enables parallel queries to all configured upstream servers.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8:53 -u 1.1.1.1:53 -u tls://dns.adguard.com --upstream-mode parallel
```
Loads the list of upstreams from a file.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u ./upstreams.txt
```
### DNS64 server
`dnsproxy` is capable of working as a DNS64 server.
> **What is DNS64/NAT64**
> This is a mechanism for providing IPv6 access to IPv4. Using a NAT64 gateway
> with IPv4-IPv6 translation capability lets IPv6-only clients connect to
> IPv4-only services via synthetic IPv6 addresses starting with a prefix that
> routes them to the NAT64 gateway. DNS64 is a DNS service that returns AAAA
> records with these synthetic IPv6 addresses for IPv4-only destinations
> (with A but not AAAA records in the DNS). This lets IPv6-only clients use
> NAT64 gateways without any other configuration.
See also [RFC 6147](https://datatracker.ietf.org/doc/html/rfc6147).
Enables DNS64 with the default [Well-Known Prefix][wkp]:
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8 --use-private-rdns --private-rdns-upstream=127.0.0.1 --dns64
```
You can also specify any number of custom DNS64 prefixes:
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8 --use-private-rdns --private-rdns-upstream=127.0.0.1 --dns64 --dns64-prefix=64:ffff:: --dns64-prefix=32:ffff::
```
Note that only the first specified prefix will be used for synthesis.
PTR queries for addresses within the specified ranges or the
[Well-Known one][wkp] can only be answered with locally appropriate data, so
dnsproxy will route those to the local upstream servers. Those should be
specified and enabled if DNS64 is enabled.
[wkp]: https://datatracker.ietf.org/doc/html/rfc6052#section-2.1
### Fastest addr + cache-min-ttl
This option is useful for users with a problematic network connection.
In this mode, `dnsproxy` detects the fastest IP address among all that were
returned and returns only that address.
Additionally, for those with a problematic network connection, it makes sense to
override `cache-min-ttl`. In this case, `dnsproxy` will make sure that DNS
responses are cached for at least the specified amount of time.
It makes sense to run it with multiple upstream servers only.
Run a DNS proxy with two upstreams, the min-TTL set to 10 minutes, and fastest
address detection enabled:
```shell
./dnsproxy -u 8.8.8.8 -u 1.1.1.1 --cache --cache-min-ttl=600 --upstream-mode=fastest_addr
```
Note that this mode is only useful for those who run `dnsproxy` with multiple upstreams.
### Specifying upstreams for domains
You can specify upstreams that will be used only for specific domains. We use the
dnsmasq-like syntax, decorating domains with brackets (see `--server`
[description][server-description]).
**Syntax:** `[/[domain1][/../domainN]/]upstreamString`
Where `upstreamString` is one or many upstreams separated by space (e.g.
`1.1.1.1` or `1.1.1.1 2.2.2.2`).
If one or more domains are specified, that upstream (`upstreamString`) is used
only for those domains. Usually, it is used for private nameservers. For
instance, if you have a nameserver on your network which deals with
`xxx.internal.local` at `192.168.0.1` then you can specify
`[/internal.local/]192.168.0.1`, and dnsproxy will send all queries to that
nameserver. Everything else will be sent to the default upstreams (which are
mandatory!).
1. An empty domain specification, `//`, has the special meaning of "unqualified
names only", which will be used to resolve names with a single label in them,
or with exactly two labels in case of `DS` requests.
2. More specific domains take precedence over less specific domains, so:
`--upstream=[/host.com/]1.2.3.4 --upstream=[/www.host.com/]2.3.4.5` will send
queries for `*.host.com` to `1.2.3.4`, except `*.www.host.com`, which will go
to `2.3.4.5`.
3. The special server address `#` means, "use the common servers", so:
`--upstream=[/host.com/]1.2.3.4 --upstream=[/www.host.com/]#` will send
queries for `*.host.com` to `1.2.3.4`, except `*.www.host.com` which will be
forwarded as usual.
4. The wildcard `*` has special meaning of "any sub-domain", so:
`--upstream=[/*.host.com/]1.2.3.4` will send queries for `*.host.com` to
`1.2.3.4`, but `host.com` will be forwarded to default upstreams.
**Examples**
Sends requests for `*.local` domains to `192.168.0.1:53`. Other requests are
sent to `8.8.8.8:53`:
```sh
./dnsproxy\
-u "8.8.8.8:53"\
-u "[/local/]192.168.0.1:53"
```
Sends requests for `*.host.com` to `1.1.1.1:53` except for `*.maps.host.com`
which are sent to `8.8.8.8:53` (along with other requests):
```sh
./dnsproxy\
-u "8.8.8.8:53"\
-u "[/host.com/]1.1.1.1:53"\
-u "[/maps.host.com/]#"
```
Sends requests for `*.host.com` to `1.1.1.1:53` except for `host.com` which is
sent to `9.9.9.10:53`, and all other requests are sent to `8.8.8.8:53`:
```sh
./dnsproxy\
-u "8.8.8.8:53"\
-u "[/host.com/]9.9.9.10:53"\
-u "[/*.host.com/]1.1.1.1:53"
```
Sends requests for `com` (and its subdomains) to `1.2.3.4:53`, requests for
other top-level domains to `1.1.1.1:53`, and all other requests to `8.8.8.8:53`:
```sh
./dnsproxy\
-u "8.8.8.8:53"\
-u "[//]1.1.1.1:53"\
-u "[/com/]1.2.3.4:53"
```
### Specifying private rDNS upstreams
You can specify upstreams that will be used for reverse DNS requests of type PTR
for private addresses. The same applies to authority requests of types SOA and
NS. The set of private addresses is defined by the `--private-subnets` option,
and the set from [RFC 6303][rfc6303] is used by default.
The additional requirement to the domains specified for upstreams is to be
`in-addr.arpa`, `ip6.arpa`, or its subdomain. Addresses encoded in the domains
should also be private.
**Examples**
Sends queries for `*.168.192.in-addr.arpa` to `192.168.1.2`, if requested by a
client from the `192.168.0.0/16` subnet. Other queries are answered with
`NXDOMAIN`:
```sh
./dnsproxy\
-l "0.0.0.0"\
-u "8.8.8.8"\
--use-private-rdns\
--private-subnets="192.168.0.0/16"\
--private-rdns-upstream="192.168.1.2"
```
Sends queries for `*.in-addr.arpa` to `192.168.1.2`, `*.ip6.arpa` to `fe80::1`,
if requested by a client within the default [RFC 6303][rfc6303] subnet set.
Other queries are answered with `NXDOMAIN`:
```sh
./dnsproxy\
-l "0.0.0.0"\
-u 8.8.8.8\
--use-private-rdns\
--private-rdns-upstream="192.168.1.2"\
--private-rdns-upstream="[/ip6.arpa/]fe80::1"
```
[rfc6303]: https://datatracker.ietf.org/doc/html/rfc6303
[server-description]: http://www.thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
### EDNS Client Subnet
To enable support for EDNS Client Subnet extension you should run dnsproxy with `--edns` flag:
```
./dnsproxy -u 8.8.8.8:53 --edns
```
Now, if you connect to the proxy from the Internet, it will pass your original IP address's prefix through to the upstream server. This way the upstream server may respond with IP addresses of the servers that are located near you to minimize latency.
If you want to use the EDNS Client Subnet feature when you're connecting to the proxy from a local network, you need to set the `--edns-addr=PUBLIC_IP` argument:
```
./dnsproxy -u 8.8.8.8:53 --edns --edns-addr=72.72.72.72
```
Now even if your IP address is 192.168.0.1 and it's not a public IP, the proxy will pass through 72.72.72.72 to the upstream server.
### Bogus NXDomain
This option is similar to dnsmasq `bogus-nxdomain`. `dnsproxy` will transform
responses that contain at least a single IP address which is also specified by
the option into `NXDOMAIN`. Can be specified multiple times.
In the example below, we use AdGuard DNS server that returns `0.0.0.0` for
blocked domains, and transform them to `NXDOMAIN`.
```
./dnsproxy -u 94.140.14.14:53 --bogus-nxdomain=0.0.0.0
```
CIDR ranges are supported as well. The following will respond with `NXDOMAIN`
instead of responses containing any IP from `192.168.0.0`-`192.168.255.255`:
```
./dnsproxy -u 192.168.0.15:53 --bogus-nxdomain=192.168.0.0/16
```
### Basic Auth for DoH
By setting the `--https-userinfo` option you can use `dnsproxy` as a DoH proxy
with basic authentication requirements.
For example:
```sh
./dnsproxy\
--https-port='443'\
--https-userinfo='user:p4ssw0rd'\
--tls-crt='…/my.crt'\
--tls-key='…/my.key'\
-u '94.140.14.14:53'
```
This configuration will only allow DoH queries that contain an `Authorization`
header containing the BasicAuth credentials for user `user` with password
`p4ssw0rd`.
Add `-p 0` if you also want to disable plain-DNS handling and make `dnsproxy`
only serve DoH with Basic Auth checking.
0707010000000C000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001D00000000dnsproxy-0.75.0/bamboo-specs0707010000000D000081A4000000000000000000000001679A649F000009AB000000000000000000000000000000000000002900000000dnsproxy-0.75.0/bamboo-specs/bamboo.yaml---
'version': 2
'plan':
'project-key': 'GO'
'key': 'DNSPROXY'
'name': 'dnsproxy - Build and run tests'
'variables':
'dockerFpm': 'alanfranz/fpm-within-docker:ubuntu-bionic'
# When there is a patch release of Go available, set this property to an
# exact patch version as opposed to a minor one to make sure that this exact
# version is actually used and not whatever the docker daemon on the CI has
# cached a few months ago.
'dockerGo': 'golang:1.23.5'
'maintainer': 'Adguard Go Team'
'name': 'dnsproxy'
'stages':
# TODO(e.burkov): Add separate lint stage for texts.
- 'Lint':
'manual': false
'final': false
'jobs':
- 'Lint'
- 'Test':
'manual': false
'final': false
'jobs':
- 'Test'
'Lint':
'docker':
'image': '${bamboo.dockerGo}'
'volumes':
'${system.GO_CACHE_DIR}': '${bamboo.cacheGo}'
'${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}'
'key': 'LINT'
'other':
'clean-working-dir': true
'requirements':
- 'adg-docker': true
'tasks':
- 'checkout':
'force-clean-build': true
- 'script':
'interpreter': 'SHELL'
'scripts':
- |
set -e -f -u -x
make VERBOSE=1 GOMAXPROCS=1 go-tools go-lint
'Test':
'docker':
'image': '${bamboo.dockerGo}'
'volumes':
'${system.GO_CACHE_DIR}': '${bamboo.cacheGo}'
'${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}'
'key': 'TEST'
'other':
'clean-working-dir': true
'requirements':
- 'adg-docker': true
'tasks':
- 'checkout':
'force-clean-build': true
- 'script':
'interpreter': 'SHELL'
# Projects that have go-bench and/or go-fuzz targets should add them
# here as well.
'scripts':
- |
set -e -f -u -x
make VERBOSE=1 go-deps go-test
'branches':
'create': 'for-pull-request'
'delete':
'after-deleted-days': 1
'after-inactive-days': 5
'link-to-jira': true
'notifications':
- 'events':
- 'plan-status-changed'
'recipients':
- 'webhook':
'name': 'Build webhook'
'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo'
'labels': []
'other':
'concurrent-build-plugin': 'system-default'
0707010000000E000081A4000000000000000000000001679A649F00000230000000000000000000000000000000000000002100000000dnsproxy-0.75.0/config.yaml.dist# This is the yaml configuration file for dnsproxy with minimal working
# configuration; all the available options can be seen with ./dnsproxy --help.
# To use it within dnsproxy specify the --config-path=/<path-to-config.yaml>
# option. Any other command-line options specified will override the values
# from the config file.
---
bootstrap:
- "8.8.8.8:53"
listen-addrs:
- "0.0.0.0"
listen-ports:
- 53
max-go-routines: 0
ratelimit: 0
ratelimit-subnet-len-ipv4: 24
ratelimit-subnet-len-ipv6: 64
udp-buf-size: 0
upstream:
- "1.1.1.1:53"
timeout: '10s'
0707010000000F000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001700000000dnsproxy-0.75.0/docker07070100000010000081A4000000000000000000000001679A649F0000072A000000000000000000000000000000000000002200000000dnsproxy-0.75.0/docker/Dockerfile# A docker file for scripts/make/build-docker.sh.
FROM alpine:3.18
ARG BUILD_DATE
ARG VERSION
ARG VCS_REF
LABEL\
maintainer="AdGuard Team <devteam@adguard.com>" \
org.opencontainers.image.authors="AdGuard Team <devteam@adguard.com>" \
org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.description="Simple DNS proxy with DoH, DoT, DoQ and DNSCrypt support" \
org.opencontainers.image.documentation="https://github.com/AdguardTeam/dnsproxy" \
org.opencontainers.image.licenses="Apache-2.0" \
org.opencontainers.image.revision=$VCS_REF \
org.opencontainers.image.source="https://github.com/AdguardTeam/dnsproxy" \
org.opencontainers.image.title="dnsproxy" \
org.opencontainers.image.url="https://github.com/AdguardTeam/dnsproxy" \
org.opencontainers.image.vendor="AdGuard" \
org.opencontainers.image.version=$VERSION
# Install CA certificates, libcap, and time zone data, and create the
# application directory owned by nobody.
RUN apk --no-cache add ca-certificates libcap tzdata && \
mkdir -p /opt/dnsproxy && chown -R nobody: /opt/dnsproxy
ARG DIST_DIR
ARG TARGETARCH
ARG TARGETOS
ARG TARGETVARIANT
COPY --chown=nobody:nogroup\
./${DIST_DIR}/docker/dnsproxy_${TARGETOS}_${TARGETARCH}_${TARGETVARIANT}\
/opt/dnsproxy/dnsproxy
COPY --chown=nobody:nogroup\
./${DIST_DIR}/docker/config.yaml\
/opt/dnsproxy/config.yaml
RUN setcap 'cap_net_bind_service=+eip' /opt/dnsproxy/dnsproxy
# 53 : TCP, UDP : DNS
# 80 : TCP : HTTP
# 443 : TCP, UDP : HTTPS, DNS-over-HTTPS (incl. HTTP/3), DNSCrypt (main)
# 853 : TCP, UDP : DNS-over-TLS, DNS-over-QUIC
# 5443 : TCP, UDP : DNSCrypt (alt)
# 6060 : TCP : HTTP (pprof)
EXPOSE 53/tcp 53/udp \
80/tcp \
443/tcp 443/udp \
853/tcp 853/udp \
5443/tcp 5443/udp \
6060/tcp
WORKDIR /opt/dnsproxy
ENTRYPOINT ["/opt/dnsproxy/dnsproxy"]
CMD ["--config-path=/opt/dnsproxy/config.yaml"]
07070100000011000081A4000000000000000000000001679A649F00000490000000000000000000000000000000000000002100000000dnsproxy-0.75.0/docker/README.md# DNS Proxy
A simple DNS proxy server that supports all existing DNS protocols including
`DNS-over-TLS`, `DNS-over-HTTPS`, `DNSCrypt`, and `DNS-over-QUIC`. Moreover,
it can work as a `DNS-over-HTTPS`, `DNS-over-TLS` or `DNS-over-QUIC` server.
Learn more about dnsproxy and its full capabilities in
its [GitHub repo][dnsproxy].
[dnsproxy]: https://github.com/AdguardTeam/dnsproxy
## Quick start
### Pull the Docker image
This command will pull the latest stable version:
```shell
docker pull adguard/dnsproxy
```
### Run the container
Run the container with the default configuration (see `config.yaml.dist` in the
repository) and expose DNS ports.
```shell
docker run --name dnsproxy \
-p 53:53/tcp -p 53:53/udp \
adguard/dnsproxy
```
Run the container with command-line args configuration and expose DNS ports.
```shell
docker run --name dnsproxy_google_dns \
-p 53:53/tcp -p 53:53/udp \
adguard/dnsproxy \
-u 8.8.8.8:53
```
Run the container with a configuration file and expose DNS ports.
```shell
docker run --name dnsproxy_google_dns \
-p 53:53/tcp -p 53:53/udp \
-v $PWD/config.yaml:/opt/dnsproxy/config.yaml \
adguard/dnsproxy
```
07070100000012000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001700000000dnsproxy-0.75.0/fastip07070100000013000081A4000000000000000000000001679A649F00000A54000000000000000000000000000000000000002000000000dnsproxy-0.75.0/fastip/cache.gopackage fastip
import (
"encoding/binary"
"net/netip"
"time"
)
// fastestAddrCacheTTLSec is how long, in seconds, a ping result for an IP
// address is kept in the cache: ten minutes.
const fastestAddrCacheTTLSec = 600
// cacheEntry represents an item that will be stored in the cache.
//
// TODO(e.burkov): Rewrite the cache using zero-values instead of storing
// useless boolean as an integer.
type cacheEntry struct {
	// status is 1 if dialing the address failed (e.g. timed out) and 0 if it
	// succeeded.
	status int

	// latencyMsec is the measured dial latency in milliseconds.  It is only
	// meaningful when status is 0.
	latencyMsec uint
}

const (
	// packedCacheEntryLen is the length of a packed cache entry in bytes:
	// expire (4) + status (1) + latency (2).
	packedCacheEntryLen = 4 + 1 + 2

	// maxPackedLatencyMsec is the largest latency representable in the two
	// latency bytes of a packed cache entry.
	maxPackedLatencyMsec = 1<<16 - 1
)

// packCacheEntry packs the cache entry and the TTL to bytes in the following
// order:
//
//   - expire [4]byte (Unix time, seconds),
//   - status byte (0 for ok, 1 for timed out),
//   - latency [2]byte (milliseconds).
//
// Latencies that don't fit into two bytes are saturated to
// maxPackedLatencyMsec instead of wrapping around, so that a very slow address
// can never be stored as a fast one.
func packCacheEntry(ent *cacheEntry, ttl uint32) (d []byte) {
	expire := uint32(time.Now().Unix()) + ttl

	latency := ent.latencyMsec
	if latency > maxPackedLatencyMsec {
		latency = maxPackedLatencyMsec
	}

	d = make([]byte, packedCacheEntryLen)
	binary.BigEndian.PutUint32(d[:4], expire)
	d[4] = byte(ent.status)
	binary.BigEndian.PutUint16(d[5:], uint16(latency))

	return d
}

// unpackCacheEntry unpacks bytes produced by packCacheEntry into a cache entry.
// It returns nil if the record is expired or if data is too short to be a
// packed entry, instead of panicking on an out-of-range read.
func unpackCacheEntry(data []byte) (ent *cacheEntry) {
	if len(data) < packedCacheEntryLen {
		return nil
	}

	expire := binary.BigEndian.Uint32(data[:4])
	if int64(expire) <= time.Now().Unix() {
		return nil
	}

	return &cacheEntry{
		status:      int(data[4]),
		latencyMsec: uint(binary.BigEndian.Uint16(data[5:])),
	}
}
// cacheFind looks up the cached ping result for ip.  It returns nil when the
// cache holds nothing for ip or the stored record has already expired.
func (f *FastestAddr) cacheFind(ip netip.Addr) (ent *cacheEntry) {
	if packed := f.ipCache.Get(ip.AsSlice()); packed != nil {
		return unpackCacheEntry(packed)
	}

	return nil
}
// cacheAddFailure records a failed ping of ip in the cache.  A still valid
// record for ip, whatever its status, is left untouched.
func (f *FastestAddr) cacheAddFailure(ip netip.Addr) {
	f.ipCacheLock.Lock()
	defer f.ipCacheLock.Unlock()

	if f.cacheFind(ip) != nil {
		return
	}

	f.cacheAdd(&cacheEntry{status: 1}, ip, fastestAddrCacheTTLSec)
}
// cacheAddSuccessful records a successful ping of ip with the given latency in
// milliseconds.  A valid cached record is replaced only when it describes a
// failure or a higher latency than the new one.
func (f *FastestAddr) cacheAddSuccessful(ip netip.Addr, latency uint) {
	f.ipCacheLock.Lock()
	defer f.ipCacheLock.Unlock()

	prev := f.cacheFind(ip)
	if prev != nil && prev.status == 0 && prev.latencyMsec <= latency {
		// The cached success is already at least as fast.
		return
	}

	f.cacheAdd(&cacheEntry{latencyMsec: latency}, ip, fastestAddrCacheTTLSec)
}
// cacheAdd packs ent with the given TTL, in seconds, and stores it in the
// cache under ip, overwriting whatever was stored there before.
func (f *FastestAddr) cacheAdd(ent *cacheEntry, ip netip.Addr, ttl uint32) {
	f.ipCache.Set(ip.AsSlice(), packCacheEntry(ent, ttl))
}
07070100000014000081A4000000000000000000000001679A649F00000912000000000000000000000000000000000000002500000000dnsproxy-0.75.0/fastip/cache_test.gopackage fastip
import (
"net"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/stretchr/testify/assert"
)
func TestCacheAdd(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})
	ip := netip.MustParseAddr("1.1.1.1")

	f.cacheAdd(&cacheEntry{status: 0, latencyMsec: 111}, ip, fastestAddrCacheTTLSec)

	// The freshly added entry must be found.
	assert.NotNil(t, f.cacheFind(ip))
}
func TestCacheTtl(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})
	ip := netip.MustParseAddr("1.1.1.1")

	// Store the entry with a TTL of one second and make sure it's found.
	f.cacheAdd(&cacheEntry{status: 0, latencyMsec: 111}, ip, 1)
	assert.NotNil(t, f.cacheFind(ip))

	// Once the TTL has definitely elapsed, the entry must be gone.
	time.Sleep(1001 * time.Millisecond)
	assert.Nil(t, f.cacheFind(ip))
}
func TestCacheAddSuccessfulOverwrite(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})
	ip := netip.MustParseAddr("1.1.1.1")
	f.cacheAddFailure(ip)

	// Check that the failure is cached.  Stop on a missing entry instead of
	// dereferencing nil below.
	ent := f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}

	assert.Equal(t, 1, ent.status)

	// Check that a successful result overwrites the cached failure.
	f.cacheAddSuccessful(ip, 11)

	ent = f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}

	assert.Equal(t, 0, ent.status)
	assert.Equal(t, uint(11), ent.latencyMsec)
}
func TestCacheAddFailureNoOverwrite(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})
	ip := netip.MustParseAddr("1.1.1.1")
	f.cacheAddSuccessful(ip, 11)

	// Check that the success is cached.  Stop on a missing entry instead of
	// dereferencing nil below.
	ent := f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}

	assert.Equal(t, 0, ent.status)

	// Check that a failure does NOT overwrite the existing record.
	f.cacheAddFailure(ip)

	// Check that the old record is still there.
	ent = f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}

	assert.Equal(t, 0, ent.status)
	assert.Equal(t, uint(11), ent.latencyMsec)
}
// TODO(ameshkov): Actually test something.
func TestCache(_ *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})

	// Put a packed entry directly into the underlying cache.
	first := cacheEntry{status: 0, latencyMsec: 111}
	f.ipCache.Set(net.ParseIP("1.1.1.1").To4(), packCacheEntry(&first, 1))

	// Add another one through the helper.
	second := cacheEntry{status: 0, latencyMsec: 222}
	f.cacheAdd(&second, netip.MustParseAddr("2.2.2.2"), fastestAddrCacheTTLSec)
}
07070100000015000081A4000000000000000000000001679A649F00001618000000000000000000000000000000000000002200000000dnsproxy-0.75.0/fastip/fastest.go// Package fastip implements the algorithm that allows to query multiple
// resolvers, ping all IP addresses that were returned, and return the fastest
// one among them.
package fastip
import (
"log/slog"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
const (
	// LogPrefix is the value that the default logger of this package attaches
	// under the [slogutil.KeyPrefix] key.
	LogPrefix = "fastip"
)
const (
	// DefaultPingWaitTimeout is how long pinging waits for the resolved
	// addresses by default.  It is used when [Config.PingWaitTimeout] is not
	// positive.
	DefaultPingWaitTimeout = time.Second
)
// FastestAddr provides methods to determine the fastest network addresses.
// Create instances with [New]: the zero value is not usable, since its lock
// and pinger are nil pointers.
type FastestAddr struct {
	// logger is used for logging during the process. It is never nil.
	logger *slog.Logger

	// pinger is the dialer with predefined timeout for pinging TCP connections.
	// The constructors set its timeout to pingTCPTimeout.
	pinger *net.Dialer

	// ipCacheLock serializes the check-then-store sequences performed by
	// cacheAddFailure and cacheAddSuccessful.  Plain lookups through cacheFind
	// (e.g. in schedulePings) are not guarded by it.
	ipCacheLock *sync.Mutex

	// ipCache caches the results of pinging IP addresses.  Keys are the
	// addresses as byte slices and values are entries encoded by
	// packCacheEntry.
	ipCache cache.Cache

	// pingPorts are the TCP ports to ping on.  Every address that needs to be
	// pinged is dialed on each of these ports.
	pingPorts []uint

	// pingWaitTimeout is the timeout for waiting all the resolved addresses to
	// be pinged. Any ping results received after that moment are cached, but
	// won't be used.
	pingWaitTimeout time.Duration
}
// NewFastestAddr initializes a new instance of *FastestAddr with the default
// settings.  It is equivalent to calling [New] with an empty [Config]; it
// delegates to New so that both constructors can't drift apart.
//
// Deprecated: Use [New] instead.
func NewFastestAddr() (f *FastestAddr) {
	return New(&Config{})
}
// Config contains the optional settings for creating a *FastestAddr with
// [New].  Every field may be left unset, in which case its documented default
// is used.
type Config struct {
	// Logger is used as the base logger for the service. If nil,
	// [slog.Default] with [LogPrefix] is used.
	Logger *slog.Logger

	// PingWaitTimeout is the timeout for waiting all the resolved addresses to
	// be pinged. Any ping results received after that moment are cached, but
	// won't be used. If zero or negative, [DefaultPingWaitTimeout] is used.
	PingWaitTimeout time.Duration
}
// New initializes a new instance of *FastestAddr configured by c.  A nil c is
// treated as an empty [Config], so all the defaults are used instead of
// panicking.
func New(c *Config) (f *FastestAddr) {
	if c == nil {
		c = &Config{}
	}

	logger := c.Logger
	if logger == nil {
		logger = slog.Default().With(slogutil.KeyPrefix, LogPrefix)
	}

	pingWaitTimeout := c.PingWaitTimeout
	if pingWaitTimeout <= 0 {
		pingWaitTimeout = DefaultPingWaitTimeout
	}

	return &FastestAddr{
		logger:      logger,
		pinger:      &net.Dialer{Timeout: pingTCPTimeout},
		ipCacheLock: &sync.Mutex{},
		ipCache: cache.New(cache.Config{
			MaxSize:   64 * 1024,
			EnableLRU: true,
		}),
		pingPorts:       []uint{80, 443},
		pingWaitTimeout: pingWaitTimeout,
	}
}
// ExchangeFastest queries each specified upstream and returns the response with
// the fastest IP address. The fastest IP address is considered to be the first
// one successfully dialed and other addresses are removed from the answer. If
// no fastest address is found, the first response is returned unchanged.
func (f *FastestAddr) ExchangeFastest(
	req *dns.Msg,
	ups []upstream.Upstream,
) (resp *dns.Msg, u upstream.Upstream, err error) {
	replies, err := upstream.ExchangeAll(ups, req)
	if err != nil {
		return nil, nil, err
	}

	// Collect the unique, specified addresses from all the answers.
	ipSet := container.NewMapSet[netip.Addr]()
	for _, r := range replies {
		for _, rr := range r.Resp.Answer {
			ip := ipFromRR(rr)
			if ip.IsValid() && !ip.IsUnspecified() {
				ipSet.Add(ip)
			}
		}
	}

	// The host is only used for logging, so a request without a question
	// section must not cause a panic.
	var host string
	if len(req.Question) > 0 {
		host = strings.ToLower(req.Question[0].Name)
	}

	if pingRes := f.pingAll(host, ipSet.Values()); pingRes != nil {
		return f.prepareReply(pingRes, replies)
	}

	f.logger.Debug("no fastest ip found, using the first response", "host", host)

	return replies[0].Resp, replies[0].Upstream, nil
}
// prepareReply picks the first of replies whose answer section contains the
// address from res, strips all other A and AAAA records from it in place, and
// returns it together with the upstream that sent it.  If no reply contains
// the address, which indicates a bug, the first reply is returned unchanged.
func (f *FastestAddr) prepareReply(
	res *pingResult,
	replies []upstream.ExchangeAllResult,
) (resp *dns.Msg, u upstream.Upstream, err error) {
	ip := res.addrPort.Addr()

	for i := range replies {
		if !hasInAns(replies[i].Resp, ip) {
			continue
		}

		resp, u = replies[i].Resp, replies[i].Upstream
		filterResponseAnswer(resp, ip)

		return resp, u, nil
	}

	f.logger.Error("found no replies, most likely this is a bug", "ip", ip)

	// TODO(d.kolyshev): Consider returning error?
	return replies[0].Resp, replies[0].Upstream, nil
}
// filterResponseAnswer rewrites the answer section of resp in place: A and AAAA
// records are kept only if their address equals ip (as compared by
// [net.IP.Equal]), while records of any other type are always kept.
func filterResponseAnswer(resp *dns.Msg, ip netip.Addr) {
	want := net.IP(ip.AsSlice())
	kept := make([]dns.RR, 0, len(resp.Answer))

	for _, rr := range resp.Answer {
		var got net.IP
		switch rec := rr.(type) {
		case *dns.A:
			got = rec.A
		case *dns.AAAA:
			got = rec.AAAA
		default:
			kept = append(kept, rr)

			continue
		}

		if got.Equal(want) {
			kept = append(kept, rr)
		}
	}

	resp.Answer = kept
}
// hasInAns reports whether any record in the answer section of m carries ip,
// as extracted by ipFromRR.
func hasInAns(m *dns.Msg, ip netip.Addr) (ok bool) {
	for _, rr := range m.Answer {
		if ok = ipFromRR(rr) == ip; ok {
			break
		}
	}

	return ok
}
// ipFromRR returns the address carried by rr if it is an A or AAAA record.  For
// any other record type the zero [netip.Addr] is returned; conversion errors
// are ignored.
func ipFromRR(rr dns.RR) (ip netip.Addr) {
	if a, isA := rr.(*dns.A); isA {
		ip, _ = netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
	} else if aaaa, isAAAA := rr.(*dns.AAAA); isAAAA {
		ip, _ = netutil.IPToAddr(aaaa.AAAA, netutil.AddrFamilyIPv6)
	}

	return ip
}
07070100000016000081A4000000000000000000000001679A649F000010DB000000000000000000000000000000000000002700000000dnsproxy-0.75.0/fastip/fastest_test.gopackage fastip
import (
"net/netip"
"testing"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFastestAddr_ExchangeFastest(t *testing.T) {
	l := slogutil.NewDiscardLogger()

	t.Run("error", func(t *testing.T) {
		const errDesired errors.Error = "this is expected"

		u := &errUpstream{
			err: errDesired,
		}
		f := New(&Config{
			Logger:          l,
			PingWaitTimeout: DefaultPingWaitTimeout,
		})

		resp, up, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{u})
		require.Error(t, err)

		assert.ErrorIs(t, err, errDesired)
		assert.Nil(t, resp)
		assert.Nil(t, up)
	})

	t.Run("one_dead", func(t *testing.T) {
		port := listen(t, netip.IPv4Unspecified())

		f := New(&Config{
			Logger:          l,
			PingWaitTimeout: DefaultPingWaitTimeout,
		})
		f.pingPorts = []uint{port}

		// The alive IP is the just created local listener's address. The dead
		// one is known as TEST-NET-1 which shouldn't be routed at all. See
		// RFC-5737 (https://datatracker.ietf.org/doc/html/rfc5737).
		aliveAddr := netip.MustParseAddr("127.0.0.1")
		alive := &testAUpstream{
			recs: []*dns.A{newTestRec(t, aliveAddr)},
		}
		dead := &testAUpstream{
			recs: []*dns.A{newTestRec(t, netip.MustParseAddr("192.0.2.1"))},
		}

		rep, ups, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{dead, alive})
		require.NoError(t, err)

		// The expected value goes first, so that failure messages are correct.
		assert.Equal(t, alive, ups)

		require.NotNil(t, rep)
		require.NotEmpty(t, rep.Answer)
		require.IsType(t, new(dns.A), rep.Answer[0])

		ip := rep.Answer[0].(*dns.A).A
		assert.Equal(t, aliveAddr.AsSlice(), []byte(ip))
	})

	t.Run("all_dead", func(t *testing.T) {
		f := New(&Config{
			Logger:          l,
			PingWaitTimeout: DefaultPingWaitTimeout,
		})
		f.pingPorts = []uint{getFreePort(t)}

		firstIP := netip.MustParseAddr("127.0.0.1")
		ups := &testAUpstream{
			recs: []*dns.A{
				newTestRec(t, firstIP),
				newTestRec(t, netip.MustParseAddr("127.0.0.2")),
				newTestRec(t, netip.MustParseAddr("127.0.0.3")),
			},
		}

		resp, _, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{ups})
		require.NoError(t, err)
		require.NotNil(t, resp)
		require.NotEmpty(t, resp.Answer)
		require.IsType(t, new(dns.A), resp.Answer[0])

		ip := resp.Answer[0].(*dns.A).A
		assert.Equal(t, firstIP.AsSlice(), []byte(ip))
	})
}
// errUpstream is a mock upstream for tests whose Exchange always fails with
// err and whose Close returns closeErr.
type errUpstream struct {
	// err is returned from every call to Exchange.
	err error

	// closeErr is returned from Close.
	closeErr error
}
// Address implements the [upstream.Upstream] interface for *errUpstream.  It
// always returns the same placeholder address.
func (u *errUpstream) Address() string {
	const placeholder = "bad_upstream"

	return placeholder
}
// Exchange implements the [upstream.Upstream] interface for *errUpstream.  The
// request is ignored; the result is always a nil message and u.err.
func (u *errUpstream) Exchange(_ *dns.Msg) (*dns.Msg, error) {
	var resp *dns.Msg

	return resp, u.err
}
// Close implements the [upstream.Upstream] interface for *errUpstream.  It
// returns the preconfigured u.closeErr.
func (u *errUpstream) Close() error {
	closeErr := u.closeErr

	return closeErr
}
// testAUpstream is a mock A upstream structure for tests.  Every exchange is
// answered with a reply to the request carrying all the records from recs.
type testAUpstream struct {
	// recs are the A records put into the answer section of every response.
	recs []*dns.A
}
// type check: *testAUpstream must implement [upstream.Upstream].
var _ upstream.Upstream = (*testAUpstream)(nil)
// Exchange implements the [upstream.Upstream] interface for *testAUpstream.  It
// replies to m with all the records from u.recs in the answer section and
// never fails.
func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	resp = new(dns.Msg).SetReply(m)
	for i := range u.recs {
		resp.Answer = append(resp.Answer, u.recs[i])
	}

	return resp, nil
}
// Address implements the [upstream.Upstream] interface for *testAUpstream.  The
// mock has no address, so the empty string is returned.
func (u *testAUpstream) Address() (addr string) {
	return addr
}
// Close implements the [upstream.Upstream] interface for *testAUpstream.  There
// is nothing to release, so it always succeeds.
func (u *testAUpstream) Close() (err error) {
	return err
}
// newTestRec returns a new A record for addr, owned by the fully-qualified name
// of the current test and having a TTL of 60 seconds.
func newTestRec(t *testing.T, addr netip.Addr) (rr *dns.A) {
	rr = &dns.A{A: addr.AsSlice()}
	rr.Hdr.Name = dns.Fqdn(t.Name())
	rr.Hdr.Rrtype = dns.TypeA
	rr.Hdr.Ttl = 60

	return rr
}
// newTestReq returns a new recursive IN A request with a random ID for the
// fully-qualified name of the current test.
func newTestReq(t *testing.T) (req *dns.Msg) {
	req = &dns.Msg{}

	// SetQuestion generates the ID, sets the RD bit, and adds a single IN
	// question.
	req.SetQuestion(dns.Fqdn(t.Name()), dns.TypeA)

	return req
}
07070100000017000081A4000000000000000000000001679A649F00000F48000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/fastip/ping.gopackage fastip
import (
"net/netip"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// pingTCPTimeout bounds a single TCP dial.  It's deliberately longer than the
// ping wait timeout: results of slower connections are still cached, even if
// they arrive too late to be used.
const pingTCPTimeout time.Duration = 4 * time.Second
// pingResult is the result of dialing the address.  Results built from cache
// records in schedulePings, as well as the single-address shortcut in pingAll,
// carry port 0 since no actual dial happened for them.
type pingResult struct {
	// addrPort is the address-port pair the result is related to.
	addrPort netip.AddrPort

	// latency is the duration of dialing process in milliseconds.  It is zero
	// for the single-address shortcut in pingAll.
	latency uint

	// success is true when the dialing succeeded.
	success bool
}
// schedulePings starts a goroutine per configured port for every address in ips
// that has no valid cache record; each goroutine sends its result to resCh.
// Among the addresses that do have cache records, the successful one with the
// lowest latency, if any, is returned as pr.  scheduled reports whether at
// least one goroutine has been started.
func (f *FastestAddr) schedulePings(
	resCh chan *pingResult,
	ips []netip.Addr,
	host string,
) (pr *pingResult, scheduled bool) {
	for _, ip := range ips {
		ent := f.cacheFind(ip)

		switch {
		case ent == nil:
			scheduled = true
			for _, port := range f.pingPorts {
				go f.pingDoTCP(host, netip.AddrPortFrom(ip, uint16(port)), resCh)
			}
		case ent.status != 0:
			// A cached failure: neither ping nor use the address.
		case pr == nil || ent.latencyMsec < pr.latency:
			pr = &pingResult{
				addrPort: netip.AddrPortFrom(ip, 0),
				latency:  ent.latencyMsec,
				success:  true,
			}
		}
	}

	return pr, scheduled
}
// pingAll pings all ips concurrently and returns as soon as the fastest one is
// found or the wait timeout is exceeded.  A cached result is preferred over a
// fresh one only if it is strictly faster.  nil is returned if no address is
// known to be reachable.
func (f *FastestAddr) pingAll(host string, ips []netip.Addr) (pr *pingResult) {
	ipN := len(ips)
	if ipN == 0 {
		return nil
	}

	if ipN == 1 {
		// There is nothing to compare with, so don't ping at all.
		return &pingResult{
			addrPort: netip.AddrPortFrom(ips[0], 0),
			success:  true,
		}
	}

	// The channel can hold every possible result, so the pinging goroutines
	// never block on sending, even after this function has returned.
	resCh := make(chan *pingResult, ipN*len(f.pingPorts))

	cached, scheduled := f.schedulePings(resCh, ips, host)
	if !scheduled {
		if cached == nil {
			f.logger.Debug("pinging all returns nothing", "host", host)
		} else {
			f.logger.Debug(
				"pinging all returns cached response",
				"host", host,
				"addr", cached.addrPort,
			)
		}

		return cached
	}

	res := f.firstSuccessRes(resCh, host)

	switch {
	case res == nil:
		// Timed out, so fall back to the cached result, if any.
		return cached
	case cached == nil, res.latency <= cached.latency:
		// There is no cached result or the fresh one is at least as fast.
		return res
	default:
		return cached
	}
}
// firstSuccessRes consumes ping results from resCh until a successful one
// arrives and returns it.  Failed results are logged and skipped.  nil is
// returned if no success arrives within f.pingWaitTimeout.
func (f *FastestAddr) firstSuccessRes(resCh chan *pingResult, host string) (res *pingResult) {
	deadline := time.After(f.pingWaitTimeout)

	for {
		select {
		case <-deadline:
			f.logger.Debug("pinging all timed out", "host", host)

			return nil
		case res = <-resCh:
			f.logger.Debug(
				"pinging all got result",
				"host", host,
				"addr", res.addrPort,
				"status", res.success,
			)

			if res.success {
				return res
			}
		}
	}
}
// pingDoTCP dials addrPort over TCP and sends the outcome to resCh.  The
// connection, if established, is closed right away.  Afterwards the outcome is
// cached under the unmapped address.  The latency is measured in whole
// milliseconds.
func (f *FastestAddr) pingDoTCP(host string, addrPort netip.AddrPort, resCh chan *pingResult) {
	l := f.logger.With("host", host, "addr", addrPort)
	l.Debug("open tcp connection")

	start := time.Now()
	conn, err := f.pinger.Dial(bootstrap.NetworkTCP, addrPort.String())
	elapsed := time.Since(start)
	latency := uint(elapsed.Milliseconds())

	if err == nil {
		if cErr := conn.Close(); cErr != nil {
			l.Debug("closing tcp connection", slogutil.KeyError, cErr)
		}
	}

	resCh <- &pingResult{
		addrPort: addrPort,
		latency:  latency,
		success:  err == nil,
	}

	addr := addrPort.Addr().Unmap()
	if err != nil {
		l.Debug("tcp ping failed to connect", "elapsed", elapsed, slogutil.KeyError, err)
		f.cacheAddFailure(addr)

		return
	}

	l.Debug("tcp ping success", "elapsed", elapsed)
	f.cacheAddSuccessful(addr, latency)
}
07070100000018000081A4000000000000000000000001679A649F00001610000000000000000000000000000000000000002400000000dnsproxy-0.75.0/fastip/ping_test.gopackage fastip
import (
"net"
"net/netip"
"runtime"
"sync"
"syscall"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// unit is the convenient alias for struct{}.  It is used as the element type
// of channels that only carry signals.
type unit = struct{}
func TestFastestAddr_PingAll_timeout(t *testing.T) {
	t.Run("isolated", func(t *testing.T) {
		f := New(&Config{Logger: slogutil.NewDiscardLogger()})

		// Block every dial until the subtest is over, so that no ping can
		// finish before the wait timeout.  Closing the channel releases all
		// the blocked dials at once, whereas a single send would release only
		// one of them and leak the rest.  The deferred close also runs when a
		// require assertion fails.
		waitCh := make(chan unit)
		defer close(waitCh)

		f.pinger.Control = func(_, _ string, _ syscall.RawConn) error {
			<-waitCh

			return nil
		}

		ip := netutil.IPv4Localhost()
		res := f.pingAll("", []netip.Addr{ip, ip})
		require.Nil(t, res)
	})

	t.Run("cached", func(t *testing.T) {
		f := New(&Config{Logger: slogutil.NewDiscardLogger()})

		const lat uint = 42

		ip1 := netutil.IPv4Localhost()
		ip2 := netip.MustParseAddr("127.0.0.2")
		f.cacheAddSuccessful(ip1, lat)

		// Block the dials of the uncached address until the subtest is over.
		// Closing the channel releases every blocked dial, so none of the
		// pinging goroutines is leaked.
		waitCh := make(chan unit)
		defer close(waitCh)

		f.pinger.Control = func(_, _ string, _ syscall.RawConn) error {
			<-waitCh

			return nil
		}

		res := f.pingAll("", []netip.Addr{ip1, ip2})
		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, lat, res.latency)
	})
}
// assertCaching asserts that, within pingTCPTimeout, the cache of f gets an
// entry for ip with the given status.
func assertCaching(t *testing.T, f *FastestAddr, ip netip.Addr, status int) {
	t.Helper()

	const tickDur = pingTCPTimeout / 16

	hasStatus := func() (ok bool) {
		entry := f.cacheFind(ip)
		if entry == nil {
			return false
		}

		return entry.status == status
	}

	assert.Eventually(t, hasStatus, pingTCPTimeout, tickDur)
}
// TestFastestAddr_PingAll_cache checks that pingAll uses the cached results
// when there are any, and pings and caches the addresses otherwise.
func TestFastestAddr_PingAll_cache(t *testing.T) {
	ip := netutil.IPv4Localhost()

	t.Run("cached_failed", func(t *testing.T) {
		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.cacheAddFailure(ip)

		res := f.pingAll("", []netip.Addr{ip, ip})
		require.Nil(t, res)
	})

	t.Run("cached_successful", func(t *testing.T) {
		const lat uint = 1

		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.cacheAddSuccessful(ip, lat)

		res := f.pingAll("", []netip.Addr{ip, ip})
		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, lat, res.latency)
	})

	t.Run("not_cached", func(t *testing.T) {
		listener, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, listener.Close)

		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.pingPorts = []uint{uint(listener.Addr().(*net.TCPAddr).Port)}
		ips := []netip.Addr{ip, ip}

		wg := &sync.WaitGroup{}
		wg.Add(len(ips) * len(f.pingPorts))

		// Control is called by the dialing code, which may run outside of the
		// test's goroutine, where t.FailNow, used by require, must not be
		// called.  Use a panicking implementation for require instead.
		pt := testutil.PanicT{}
		f.pinger.Control = func(_, address string, _ syscall.RawConn) (err error) {
			hostport, err := netutil.ParseHostPort(address)
			require.NoError(pt, err)

			assert.Equal(t, ip.String(), hostport.Host)
			assert.Contains(t, f.pingPorts, uint(hostport.Port))

			wg.Done()

			return nil
		}

		res := f.pingAll("", ips)
		require.NotNil(t, res)

		assert.True(t, res.success)
		assertCaching(t, f, ip, 0)

		wg.Wait()
	})
}
// listen starts a TCP listener on ip with a system-chosen port and returns
// that port.  The listener is closed on the cleanup of t.
func listen(t *testing.T, ip netip.Addr) (port uint) {
	t.Helper()

	addr := netip.AddrPortFrom(ip, 0).String()
	l, err := net.Listen("tcp", addr)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, l.Close)

	tcpAddr := l.Addr().(*net.TCPAddr)

	return uint(tcpAddr.Port)
}
// TestFastestAddr_PingAll checks the choice of the fastest address by pingAll
// for a single address, several addresses, no addresses, and addresses that
// can't be connected to.
func TestFastestAddr_PingAll(t *testing.T) {
	ip := netutil.IPv4Localhost()

	t.Run("single", func(t *testing.T) {
		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		res := f.pingAll("", []netip.Addr{ip})
		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, ip, res.addrPort.Addr())

		// There was no ping so the port is zero.
		assert.Zero(t, res.addrPort.Port())

		// Nothing in the cache since there was no ping.
		ce := f.cacheFind(res.addrPort.Addr())
		require.Nil(t, ce)
	})

	t.Run("fastest", func(t *testing.T) {
		fastPort := listen(t, ip)
		slowPort := listen(t, ip)

		// ctrlCh holds the dials to slowPort until it's closed.
		ctrlCh := make(chan unit)

		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.pingPorts = []uint{
			fastPort,
			slowPort,
		}

		// Control may be called outside of the test's goroutine, where
		// t.FailNow, used by require, must not be called.
		pt := testutil.PanicT{}
		f.pinger.Control = func(_, address string, _ syscall.RawConn) error {
			addrPort := netip.MustParseAddrPort(address)
			require.Contains(pt, []uint{fastPort, slowPort}, uint(addrPort.Port()))

			if addrPort.Port() == uint16(fastPort) {
				return nil
			}

			<-ctrlCh

			return nil
		}

		ips := []netip.Addr{ip, ip}

		res := f.pingAll("", ips)

		// Release every blocked slow dial, since there is one per address, and
		// a single send would release only one of them.
		close(ctrlCh)

		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, ip, res.addrPort.Addr())
		assert.EqualValues(t, fastPort, res.addrPort.Port())

		assertCaching(t, f, ip, 0)
	})

	t.Run("zero", func(t *testing.T) {
		res := New(&Config{Logger: slogutil.NewDiscardLogger()}).pingAll("", nil)
		require.Nil(t, res)
	})

	t.Run("fail", func(t *testing.T) {
		port := getFreePort(t)

		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.pingPorts = []uint{port}

		res := f.pingAll("test", []netip.Addr{ip, ip})
		require.Nil(t, res)

		assertCaching(t, f, ip, 1)
	})
}
// getFreePort returns a TCP port on 127.0.0.1 that was free a moment ago, by
// binding to a system-chosen port and releasing it right away.
//
// TODO(e.burkov): The logic is underwhelming. Find a more accurate way.
func getFreePort(t *testing.T) (port uint) {
	t.Helper()

	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	tcpAddr := l.Addr().(*net.TCPAddr)
	port = uint(tcpAddr.Port)

	// Release the port, so that no one listens on it anymore.
	require.NoError(t, l.Close())

	if runtime.GOOS != "windows" {
		return port
	}

	// Windows may need some time before the port is actually released.
	time.Sleep(100 * time.Millisecond)

	return port
}
07070100000019000081A4000000000000000000000001679A649F000005C4000000000000000000000000000000000000001700000000dnsproxy-0.75.0/go.modmodule github.com/AdguardTeam/dnsproxy
go 1.23.5
require (
github.com/AdguardTeam/golibs v0.31.0
github.com/ameshkov/dnscrypt/v2 v2.3.0
github.com/ameshkov/dnsstamps v1.0.3
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0
github.com/bluele/gcache v0.0.2
github.com/miekg/dns v1.1.62
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/quic-go/quic-go v0.48.2
github.com/stretchr/testify v1.10.0
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67
golang.org/x/net v0.33.0
golang.org/x/sys v0.28.0
gonum.org/v1/gonum v0.15.1
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect
github.com/kr/text v0.2.0 // indirect
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/onsi/ginkgo/v2 v2.22.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
go.uber.org/mock v0.5.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/mod v0.22.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/tools v0.28.0 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
)
0707010000001A000081A4000000000000000000000001679A649F00001A2B000000000000000000000000000000000000001700000000dnsproxy-0.75.0/go.sumgithub.com/AdguardTeam/golibs v0.31.0 h1:Z0oPfLTLw6iZmpE58dePy2Bel0MaX+lnDwtFEE5EmIo=
github.com/AdguardTeam/golibs v0.31.0/go.mod h1:wIkZ9o2UnppeW6/YD7yJB71dYbMhiuC1Fh/I2ElW7GQ=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us=
github.com/ameshkov/dnscrypt/v2 v2.3.0 h1:pDXDF7eFa6Lw+04C0hoMh8kCAQM8NwUdFEllSP2zNLs=
github.com/ameshkov/dnscrypt/v2 v2.3.0/go.mod h1:N5hDwgx2cNb4Ay7AhvOSKst+eUiOZ/vbKRO9qMpQttE=
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
github.com/ameshkov/dnsstamps v1.0.3/go.mod h1:Ii3eUu73dx4Vw5O4wjzmT5+lkCwovjzaEZZ4gKyIH5A=
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 h1:0b2vaepXIfMsG++IsjHiI2p4bxALD1Y2nQKGMR5zDQM=
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA=
github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw=
github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg=
github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/onsi/ginkgo/v2 v2.22.1 h1:QW7tbJAUDyVDVOM5dFa7qaybo+CRfR7bemlQUN6Z8aM=
github.com/onsi/ginkgo/v2 v2.22.1/go.mod h1:S6aTpoRsSq2cZOd+pssHAlKW/Q/jZt6cPrPlnj4a1xM=
github.com/onsi/gomega v1.36.1 h1:bJDPBO7ibjxcbHMgSCoo4Yj18UWbKDlLwX1x9sybDcw=
github.com/onsi/gomega v1.36.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8=
golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw=
gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0=
gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
0707010000001B000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001900000000dnsproxy-0.75.0/internal0707010000001C000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002300000000dnsproxy-0.75.0/internal/bootstrap0707010000001D000081A4000000000000000000000001679A649F00000ECA000000000000000000000000000000000000003000000000dnsproxy-0.75.0/internal/bootstrap/bootstrap.go// Package bootstrap provides types and functions to resolve upstream hostnames
// and to dial retrieved addresses.
package bootstrap
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"net/url"
"slices"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
)
// Network is a network type for use in [Resolver]'s methods.  It's an alias
// for string, so the constants below and plain strings are interchangeable.
type Network = string

// Network types used by [Resolver] lookups and by [DialHandler] dials.
const (
	// NetworkIP is a network type for both address families.
	NetworkIP Network = "ip"

	// NetworkIP4 is a network type for IPv4 address family.
	NetworkIP4 Network = "ip4"

	// NetworkIP6 is a network type for IPv6 address family.
	NetworkIP6 Network = "ip6"

	// NetworkTCP is a network type for TCP connections.
	NetworkTCP Network = "tcp"

	// NetworkUDP is a network type for UDP connections.
	NetworkUDP Network = "udp"
)

// DialHandler is a dial function for creating unencrypted network connections
// to the upstream server.  It establishes the connection to the server
// specified at initialization and ignores the addr.  network must be one of
// [NetworkTCP] or [NetworkUDP].
type DialHandler func(ctx context.Context, network Network, addr string) (conn net.Conn, err error)
// ResolveDialContext returns a DialHandler that uses addresses resolved from u
// using r.  The lookup is bounded by timeout, if it's positive, and the same
// timeout is used for each dial.  Resolved addresses are ordered so that the
// preferred address family, IPv6 if preferV6 is true and IPv4 otherwise, is
// dialed first.  Invalid addresses returned by r are skipped.  l and u must
// not be nil.
func ResolveDialContext(
	u *url.URL,
	timeout time.Duration,
	r Resolver,
	preferV6 bool,
	l *slog.Logger,
) (h DialHandler, err error) {
	defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }()

	host, port, err := netutil.SplitHostPort(u.Host)
	if err != nil {
		// Don't wrap the error since it's informative enough as is and there is
		// already deferred annotation here.
		return nil, err
	}

	if r == nil {
		return nil, fmt.Errorf("resolver is nil: %w", ErrNoResolvers)
	}

	ctx := context.Background()
	if timeout > 0 {
		var cancel func()
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	// TODO(e.burkov): Use network properly, perhaps, pass it through options.
	ips, err := r.LookupNetIP(ctx, NetworkIP, host)
	if err != nil {
		return nil, fmt.Errorf("resolving hostname: %w", err)
	}

	// Copy the valid addresses into a new slice:  an invalid address can't be
	// dialed, and sorting the resolver's own slice in place could corrupt the
	// data it owns, e.g. cached or static addresses.
	valid := make([]netip.Addr, 0, len(ips))
	for _, ip := range ips {
		if ip.IsValid() {
			valid = append(valid, ip)
		}
	}

	if preferV6 {
		slices.SortStableFunc(valid, netutil.PreferIPv6)
	} else {
		slices.SortStableFunc(valid, netutil.PreferIPv4)
	}

	addrs := make([]string, 0, len(valid))
	for _, ip := range valid {
		addrs = append(addrs, netip.AddrPortFrom(ip, port).String())
	}

	return NewDialContext(timeout, l, addrs...), nil
}
// NewDialContext returns a DialHandler that tries addrs one after another, in
// the given order, and returns the first connection that succeeds, or all the
// dial errors joined.  The network passed to the handler is used for every
// dial, while its addr argument is ignored.  Each dial is limited by timeout.
// At least a single addr should be specified, otherwise the handler always
// fails.  l must not be nil.
func NewDialContext(timeout time.Duration, l *slog.Logger, addrs ...string) (h DialHandler) {
	total := len(addrs)
	if total == 0 {
		l.Debug("no addresses to dial")

		return func(_ context.Context, _, _ string) (conn net.Conn, err error) {
			return nil, errors.Error("no addresses")
		}
	}

	d := &net.Dialer{Timeout: timeout}

	return func(ctx context.Context, network Network, _ string) (conn net.Conn, err error) {
		var dialErrs []error

		// Deliberately dial the addresses captured at construction rather
		// than the one passed to the handler.
		for idx := 0; idx < total; idx++ {
			addr := addrs[idx]
			addrLog := l.With("addr", addr)
			addrLog.DebugContext(ctx, "dialing", "idx", idx+1, "total", total)

			startedAt := time.Now()
			c, dialErr := d.DialContext(ctx, network, addr)
			took := time.Since(startedAt)

			if dialErr == nil {
				addrLog.DebugContext(ctx, "connection succeeded", "elapsed", took)

				return c, nil
			}

			addrLog.DebugContext(ctx, "connection failed", "elapsed", took, slogutil.KeyError, dialErr)
			dialErrs = append(dialErrs, dialErr)
		}

		return nil, errors.Join(dialErrs...)
	}
}
0707010000001E000081A4000000000000000000000001679A649F0000113C000000000000000000000000000000000000003500000000dnsproxy-0.75.0/internal/bootstrap/bootstrap_test.gopackage bootstrap_test
import (
"context"
"net"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTimeout is a common timeout used in tests of this package, both for
// dialing and lookups and for the channel sends and receives that synchronize
// the tests.
const testTimeout = 1 * time.Second
// newListener starts a listener of the given network type on a system-chosen
// port of 127.0.0.1 and returns its address.  The listener is closed on the
// cleanup of t.  For each accepted connection, its local address is sent into
// sig, which therefore must be read from accordingly, and the connection is
// closed afterwards.
func newListener(t testing.TB, network string, sig chan net.Addr) (ipp netip.AddrPort) {
	t.Helper()

	// TODO(e.burkov): Listen IPv6 as well, when the CI adds IPv6 interfaces.
	ln, err := net.Listen(network, "127.0.0.1:0")
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, ln.Close)

	go func() {
		pt := testutil.PanicT{}
		for {
			c, acceptErr := ln.Accept()
			if errors.Is(acceptErr, net.ErrClosed) {
				// The listener has been closed on cleanup.
				return
			}

			require.NoError(pt, acceptErr)
			testutil.RequireSend(pt, sig, c.LocalAddr(), testTimeout)
			require.NoError(pt, c.Close())
		}
	}()

	ipp, err = netip.ParseAddrPort(ln.Addr().String())
	require.NoError(t, err)

	return ipp
}
// TestResolveDialContext checks that the handler returned by
// [bootstrap.ResolveDialContext] connects to one of the resolved addresses,
// and that bad input is reported.
//
// See the details here: https://github.com/AdguardTeam/dnsproxy/issues/18
func TestResolveDialContext(t *testing.T) {
	sig := make(chan net.Addr, 1)
	ipp := newListener(t, "tcp", sig)
	port := ipp.Port()

	l := slogutil.NewDiscardLogger()

	testCases := []struct {
		name       string
		addresses  []netip.Addr
		preferIPv6 bool
	}{{
		name:       "v4",
		addresses:  []netip.Addr{netutil.IPv4Localhost()},
		preferIPv6: false,
	}, {
		name:       "both_prefer_v6",
		addresses:  []netip.Addr{netutil.IPv4Localhost(), netutil.IPv6Localhost()},
		preferIPv6: true,
	}, {
		name:       "both_prefer_v4",
		addresses:  []netip.Addr{netutil.IPv6Localhost(), netutil.IPv4Localhost()},
		preferIPv6: false,
	}, {
		name:       "strip_invalid",
		addresses:  []netip.Addr{{}, netutil.IPv4Localhost(), {}, netutil.IPv6Localhost(), {}},
		preferIPv6: true,
	}}

	const hostname = "host.name"

	pt := testutil.PanicT{}
	for _, tc := range testCases {
		r := &testResolver{
			onLookupNetIP: func(
				_ context.Context,
				network string,
				host string,
			) (addrs []netip.Addr, err error) {
				require.Equal(pt, bootstrap.NetworkIP, network)
				require.Equal(pt, hostname, host)

				return tc.addresses, nil
			},
		}

		t.Run(tc.name, func(t *testing.T) {
			dialContext, err := bootstrap.ResolveDialContext(
				&url.URL{Host: netutil.JoinHostPort(hostname, port)},
				testTimeout,
				bootstrap.ParallelResolver{r},
				tc.preferIPv6,
				l,
			)
			require.NoError(t, err)

			conn, err := dialContext(context.Background(), bootstrap.NetworkTCP, "")
			require.NoError(t, err)

			// Close the client side of the connection, which would otherwise be
			// leaked for the rest of the test run.
			testutil.CleanupAndRequireSuccess(t, conn.Close)

			expected, ok := testutil.RequireReceive(t, sig, testTimeout)
			require.True(t, ok)

			assert.Equal(t, expected.String(), conn.RemoteAddr().String())
		})
	}

	t.Run("no_addresses", func(t *testing.T) {
		r := &testResolver{
			onLookupNetIP: func(
				_ context.Context,
				network string,
				host string,
			) (addrs []netip.Addr, err error) {
				require.Equal(pt, bootstrap.NetworkIP, network)
				require.Equal(pt, hostname, host)

				return nil, nil
			},
		}

		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: netutil.JoinHostPort(hostname, port)},
			testTimeout,
			bootstrap.ParallelResolver{r},
			false,
			l,
		)
		require.NoError(t, err)

		_, err = dialContext(context.Background(), bootstrap.NetworkTCP, "")
		testutil.AssertErrorMsg(t, "no addresses", err)
	})

	t.Run("bad_hostname", func(t *testing.T) {
		const errMsg = `dialing "bad hostname": address bad hostname: ` +
			`missing port in address`

		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: "bad hostname"},
			testTimeout,
			nil,
			false,
			l,
		)
		testutil.AssertErrorMsg(t, errMsg, err)

		assert.Nil(t, dialContext)
	})

	t.Run("no_resolvers", func(t *testing.T) {
		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: netutil.JoinHostPort(hostname, port)},
			testTimeout,
			nil,
			false,
			l,
		)
		assert.ErrorIs(t, err, bootstrap.ErrNoResolvers)

		assert.Nil(t, dialContext)
	})
}
0707010000001F000081A4000000000000000000000001679A649F000000BC000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/bootstrap/error.gopackage bootstrap
import "github.com/AdguardTeam/golibs/errors"
// ErrNoResolvers is returned when zero resolvers specified, e.g. by the
// multi-resolver types' LookupNetIP methods called on an empty list and by
// ResolveDialContext called with a nil resolver.
const ErrNoResolvers errors.Error = "no resolvers specified"
07070100000020000081A4000000000000000000000001679A649F00000FE3000000000000000000000000000000000000002F00000000dnsproxy-0.75.0/internal/bootstrap/resolver.gopackage bootstrap
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"slices"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// Resolver resolves the hostnames to IP addresses.  Note that [net.Resolver]
// from the standard library also implements this interface.
type Resolver interface {
	// LookupNetIP looks up the IP addresses for the given host.  network should
	// be one of [NetworkIP], [NetworkIP4] or [NetworkIP6].  The response may be
	// empty even if err is nil.  All the addrs must be valid.
	LookupNetIP(ctx context.Context, network Network, host string) (addrs []netip.Addr, err error)
}

// type check
var _ Resolver = &net.Resolver{}

// ParallelResolver is a slice of resolvers that are queried concurrently.  The
// first successful response is returned, even if it's empty.
type ParallelResolver []Resolver

// type check
var _ Resolver = ParallelResolver(nil)
// LookupNetIP implements the [Resolver] interface for ParallelResolver.  A
// single resolver is queried directly; otherwise all of them are queried in
// separate goroutines and the first successful result wins.  If every
// resolver fails, their errors are joined.
func (r ParallelResolver) LookupNetIP(
	ctx context.Context,
	network Network,
	host string,
) (addrs []netip.Addr, err error) {
	n := len(r)
	if n == 0 {
		return nil, ErrNoResolvers
	}

	if n == 1 {
		return r[0].LookupNetIP(ctx, network, host)
	}

	// The buffer holds a result from every resolver, so that the goroutines
	// whose results are never read after an early return don't block forever.
	results := make(chan any, n)
	for _, resolver := range r {
		go lookupAsync(ctx, resolver, network, host, results)
	}

	errs := make([]error, 0, n)
	for i := 0; i < n; i++ {
		res := <-results
		if found, ok := res.([]netip.Addr); ok {
			return found, nil
		}

		if lookupErr, ok := res.(error); ok {
			errs = append(errs, lookupErr)
		}
	}

	return nil, errors.Join(errs...)
}
// recoverAndLog must be deferred directly.  If a panic is in progress, it
// recovers from it and turns the panic value into an error.  It logs that
// error, followed by the stack, using the logger from ctx, or the default
// logger when ctx has none.  Finally it sends the error into resCh.
//
// TODO(a.garipov): Move this helper to golibs.
func recoverAndLog(ctx context.Context, resCh chan<- any) {
	// recover only takes effect when called directly by the deferred function.
	v := recover()
	if v == nil {
		return
	}

	var err error
	switch x := v.(type) {
	case error:
		err = x
	default:
		err = fmt.Errorf("error value: %v", v)
	}

	logger := slog.Default()
	if ctxLogger, ok := slogutil.LoggerFromContext(ctx); ok {
		logger = ctxLogger
	}

	logger.ErrorContext(ctx, "recovered panic", slogutil.KeyError, err)
	slogutil.PrintStack(ctx, logger, slog.LevelError)

	resCh <- err
}
// lookupAsync looks up host using r and sends either the resulting addresses
// or the lookup error into resCh.  A panic in r is recovered and sent into
// resCh as an error.  It is intended to be used as a goroutine.
func lookupAsync(ctx context.Context, r Resolver, network, host string, resCh chan<- any) {
	// TODO(d.kolyshev): Propose better solution to recover without requiring
	// logger in the context.
	defer recoverAndLog(ctx, resCh)

	addrs, err := r.LookupNetIP(ctx, network, host)

	var res any = addrs
	if err != nil {
		res = err
	}

	resCh <- res
}
// ConsequentResolver is a slice of resolvers that are queried in order until
// the first successful non-empty response, as opposed to just successful
// response requirement in [ParallelResolver].  A resolver is queried only
// after the previous one has failed or returned no addresses.
type ConsequentResolver []Resolver

// type check
var _ Resolver = ConsequentResolver(nil)
// LookupNetIP implements the [Resolver] interface for ConsequentResolver.  It
// returns the first non-empty successful result.  Otherwise it returns the
// joined errors of the failed resolvers, which is nil if none of them failed
// and all the responses were just empty.
func (resolvers ConsequentResolver) LookupNetIP(
	ctx context.Context,
	network Network,
	host string,
) (addrs []netip.Addr, err error) {
	if len(resolvers) == 0 {
		return nil, ErrNoResolvers
	}

	errs := make([]error, 0, len(resolvers))
	for i := range resolvers {
		found, lookupErr := resolvers[i].LookupNetIP(ctx, network, host)
		if lookupErr != nil {
			errs = append(errs, lookupErr)

			continue
		}

		if len(found) > 0 {
			return found, nil
		}
	}

	return nil, errors.Join(errs...)
}
// StaticResolver is a resolver which always responds with an underlying slice
// of IP addresses regardless of host and network.  Each response is a copy of
// the slice.
type StaticResolver []netip.Addr

// type check
var _ Resolver = StaticResolver(nil)
// LookupNetIP implements the [Resolver] interface for StaticResolver.  It
// ignores all of its arguments and never fails.
func (r StaticResolver) LookupNetIP(
	_ context.Context,
	_ Network,
	_ string,
) (addrs []netip.Addr, err error) {
	// Return a copy, so that callers may freely modify the result, e.g. sort
	// it, without changing the resolver's own addresses.
	addrs = slices.Clone(r)

	return addrs, nil
}
07070100000021000081A4000000000000000000000001679A649F00000A9A000000000000000000000000000000000000003400000000dnsproxy-0.75.0/internal/bootstrap/resolver_test.gopackage bootstrap_test
import (
"context"
"net/netip"
"strings"
"testing"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testResolver is the [bootstrap.Resolver] interface implementation for
// testing purposes.  Its lookups are delegated to onLookupNetIP, which must
// not be nil.
type testResolver struct {
	onLookupNetIP func(ctx context.Context, network, host string) (addrs []netip.Addr, err error)
}
// LookupNetIP implements the [bootstrap.Resolver] interface for *testResolver
// by calling r.onLookupNetIP with the same arguments.
func (r *testResolver) LookupNetIP(
	ctx context.Context,
	network string,
	host string,
) (addrs []netip.Addr, err error) {
	addrs, err = r.onLookupNetIP(ctx, network, host)

	return addrs, err
}
// TestLookupParallel checks [bootstrap.ParallelResolver] with no resolvers, a
// single resolver, an immediate resolver racing a delayed one, and several
// resolvers that all fail.
func TestLookupParallel(t *testing.T) {
	const hostname = "host.name"
	t.Run("no_resolvers", func(t *testing.T) {
		addrs, err := bootstrap.ParallelResolver(nil).LookupNetIP(context.Background(), "ip", "")
		assert.ErrorIs(t, err, bootstrap.ErrNoResolvers)
		assert.Nil(t, addrs)
	})
	// pt is used for the requirements inside the resolvers' callbacks, since
	// with several resolvers these are called from separate goroutines.
	pt := testutil.PanicT{}
	hostAddrs := []netip.Addr{netutil.IPv4Localhost()}
	// immediate answers with hostAddrs right away.
	immediate := &testResolver{
		onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
			require.Equal(pt, hostname, host)
			require.Equal(pt, "ip", network)
			return hostAddrs, nil
		},
	}
	t.Run("one_resolver", func(t *testing.T) {
		addrs, err := bootstrap.ParallelResolver{immediate}.LookupNetIP(
			context.Background(),
			"ip",
			hostname,
		)
		require.NoError(t, err)
		assert.Equal(t, hostAddrs, addrs)
	})
	t.Run("two_resolvers", func(t *testing.T) {
		// delayCh holds the delayed resolver until the lookup has returned, so
		// the immediate resolver's answer is expected to win.
		delayCh := make(chan struct{}, 1)
		delayed := &testResolver{
			onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
				require.Equal(pt, hostname, host)
				require.Equal(pt, "ip", network)
				testutil.RequireReceive(pt, delayCh, testTimeout)
				return []netip.Addr{netutil.IPv6Localhost()}, nil
			},
		}
		addrs, err := bootstrap.ParallelResolver{immediate, delayed}.LookupNetIP(
			context.Background(),
			"ip",
			hostname,
		)
		require.NoError(t, err)
		// Release the delayed resolver, so that its goroutine finishes.
		testutil.RequireSend(t, delayCh, struct{}{}, testTimeout)
		assert.Equal(t, hostAddrs, addrs)
	})
	t.Run("all_errors", func(t *testing.T) {
		err := assert.AnError
		errStr := err.Error()
		// The errors of all the resolvers are expected to be joined, one per
		// line.
		wantErrMsg := strings.Join([]string{errStr, errStr, errStr}, "\n")
		r := &testResolver{
			onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
				return nil, assert.AnError
			},
		}
		addrs, err := bootstrap.ParallelResolver{r, r, r}.LookupNetIP(
			context.Background(),
			"ip",
			hostname,
		)
		testutil.AssertErrorMsg(t, wantErrMsg, err)
		assert.Nil(t, addrs)
	})
}
07070100000022000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001D00000000dnsproxy-0.75.0/internal/cmd07070100000023000081A4000000000000000000000001679A649F000041D2000000000000000000000000000000000000002500000000dnsproxy-0.75.0/internal/cmd/args.gopackage cmd
import (
"flag"
"fmt"
"io"
"os"
"slices"
"strings"
"github.com/AdguardTeam/dnsproxy/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/timeutil"
)
// Indexes to help with the [commandLineOptions] initialization.  Each constant
// is used as the index of the corresponding option in the [commandLineOptions]
// slice literal, so the order of the constants here defines the positions of
// the options in that slice.
const (
	configPathIdx = iota
	logOutputIdx
	tlsCertPathIdx
	tlsKeyPathIdx
	httpsServerNameIdx
	httpsUserinfoIdx
	dnsCryptConfigPathIdx
	ednsAddrIdx
	upstreamModeIdx
	listenAddrsIdx
	listenPortsIdx
	httpsListenPortsIdx
	tlsListenPortsIdx
	quicListenPortsIdx
	dnsCryptListenPortsIdx
	upstreamsIdx
	bootstrapDNSIdx
	fallbacksIdx
	privateRDNSUpstreamsIdx
	dns64PrefixIdx
	privateSubnetsIdx
	bogusNXDomainIdx
	hostsFilesIdx
	timeoutIdx
	cacheMinTTLIdx
	cacheMaxTTLIdx
	cacheSizeBytesIdx
	ratelimitIdx
	ratelimitSubnetLenIPv4Idx
	ratelimitSubnetLenIPv6Idx
	udpBufferSizeIdx
	maxGoRoutinesIdx
	tlsMinVersionIdx
	tlsMaxVersionIdx
	helpIdx
	hostsFileEnabledIdx
	pprofIdx
	versionIdx
	verboseIdx
	insecureIdx
	ipv6DisabledIdx
	http3Idx
	cacheOptimisticIdx
	cacheIdx
	refuseAnyIdx
	enableEDNSSubnetIdx
	dns64Idx
	usePrivateRDNSIdx
)
// commandLineOption contains information about a command-line option: its long
// and, if there is one, short forms, the value type, and the description.
type commandLineOption struct {
	// description is the human-readable explanation of the option.
	description string

	// long is the long name of the option without dashes, e.g. "config-path".
	long string

	// short is the one-letter name of the option, e.g. "o", or an empty
	// string if the option has no short form.
	short string

	// valueType is the kind of the option's value, e.g. "path" or "port".
	// NOTE(review): presumably shown as the value placeholder in the usage
	// text; confirm against the help printer.
	valueType string
}
// commandLineOptions are all command-line options currently supported by the
// binary.  Entries are keyed by the *Idx constants, so each one lines up with
// the configuration field pointer registered under the same index in
// parseCmdLineOptions.  The description is shown verbatim by usage, and
// valueType is the value hint printed next to the option name.
var commandLineOptions = []*commandLineOption{
	configPathIdx: {
		description: "YAML configuration file. Minimal working configuration in config.yaml.dist." +
			" Options passed through command line will override the ones from this file.",
		long:      "config-path",
		short:     "",
		valueType: "path",
	},
	logOutputIdx: {
		description: `Path to the log file.`,
		long:        "output",
		short:       "o",
		valueType:   "path",
	},
	tlsCertPathIdx: {
		description: "Path to a file with the certificate chain.",
		long:        "tls-crt",
		short:       "c",
		valueType:   "path",
	},
	tlsKeyPathIdx: {
		description: "Path to a file with the private key.",
		long:        "tls-key",
		short:       "k",
		valueType:   "path",
	},
	httpsServerNameIdx: {
		description: "Set the Server header for the responses from the HTTPS server.",
		long:        "https-server-name",
		short:       "",
		valueType:   "name",
	},
	httpsUserinfoIdx: {
		description: "If set, all DoH queries are required to have this basic authentication " +
			"information.",
		long:      "https-userinfo",
		short:     "",
		valueType: "name",
	},
	dnsCryptConfigPathIdx: {
		description: "Path to a file with DNSCrypt configuration. You can generate one using " +
			"https://github.com/ameshkov/dnscrypt.",
		long:      "dnscrypt-config",
		short:     "g",
		valueType: "path",
	},
	ednsAddrIdx: {
		description: "Send EDNS Client Address.",
		long:        "edns-addr",
		short:       "",
		valueType:   "address",
	},
	upstreamModeIdx: {
		description: "Defines the upstreams logic mode, possible values: load_balance, parallel, " +
			"fastest_addr (default: load_balance).",
		long:      "upstream-mode",
		short:     "",
		valueType: "mode",
	},
	listenAddrsIdx: {
		description: "Listening addresses.",
		long:        "listen",
		short:       "l",
		valueType:   "address",
	},
	listenPortsIdx: {
		description: "Listening ports. Zero value disables TCP and UDP listeners.",
		long:        "port",
		short:       "p",
		valueType:   "port",
	},
	httpsListenPortsIdx: {
		description: "Listening ports for DNS-over-HTTPS.",
		long:        "https-port",
		short:       "s",
		valueType:   "port",
	},
	tlsListenPortsIdx: {
		description: "Listening ports for DNS-over-TLS.",
		long:        "tls-port",
		short:       "t",
		valueType:   "port",
	},
	quicListenPortsIdx: {
		description: "Listening ports for DNS-over-QUIC.",
		long:        "quic-port",
		short:       "q",
		valueType:   "port",
	},
	dnsCryptListenPortsIdx: {
		description: "Listening ports for DNSCrypt.",
		long:        "dnscrypt-port",
		short:       "y",
		valueType:   "port",
	},
	upstreamsIdx: {
		description: "An upstream to be used (can be specified multiple times). You can also " +
			"specify path to a file with the list of servers.",
		long:      "upstream",
		short:     "u",
		valueType: "",
	},
	bootstrapDNSIdx: {
		description: "Bootstrap DNS for DoH and DoT, can be specified multiple times (default: " +
			"use system-provided).",
		long:      "bootstrap",
		short:     "b",
		valueType: "",
	},
	fallbacksIdx: {
		description: "Fallback resolvers to use when regular ones are unavailable, can be " +
			"specified multiple times. You can also specify path to a file with the list of servers.",
		long:      "fallback",
		short:     "f",
		valueType: "",
	},
	privateRDNSUpstreamsIdx: {
		description: "Private DNS upstreams to use for reverse DNS lookups of private addresses, " +
			"can be specified multiple times.",
		long:      "private-rdns-upstream",
		short:     "",
		valueType: "",
	},
	dns64PrefixIdx: {
		description: "Prefix used to handle DNS64. If not specified, dnsproxy uses the " +
			"'Well-Known Prefix' 64:ff9b::. Can be specified multiple times.",
		long:      "dns64-prefix",
		short:     "",
		valueType: "subnet",
	},
	privateSubnetsIdx: {
		description: "Private subnets to use for reverse DNS lookups of private addresses.",
		long:        "private-subnets",
		short:       "",
		valueType:   "subnet",
	},
	bogusNXDomainIdx: {
		description: "Transform the responses containing at least a single IP that matches " +
			"specified addresses and CIDRs into NXDOMAIN. Can be specified multiple times.",
		long:      "bogus-nxdomain",
		short:     "",
		valueType: "subnet",
	},
	hostsFilesIdx: {
		description: "List of paths to the hosts files, can be specified multiple times.",
		long:        "hosts-files",
		short:       "",
		valueType:   "path",
	},
	timeoutIdx: {
		// Ends with a period like every other description.
		description: "Timeout for outbound DNS queries to remote upstream servers in a " +
			"human-readable form.",
		long:      "timeout",
		short:     "",
		valueType: "duration",
	},
	cacheMinTTLIdx: {
		description: "Minimum TTL value for DNS entries, in seconds. Capped at 3600. " +
			"Artificially extending TTLs should only be done with careful consideration.",
		long:      "cache-min-ttl",
		short:     "",
		valueType: "uint32",
	},
	cacheMaxTTLIdx: {
		description: "Maximum TTL value for DNS entries, in seconds.",
		long:        "cache-max-ttl",
		short:       "",
		valueType:   "uint32",
	},
	cacheSizeBytesIdx: {
		description: "Cache size (in bytes). Default: 64k.",
		long:        "cache-size",
		short:       "",
		valueType:   "int",
	},
	ratelimitIdx: {
		description: "Ratelimit (requests per second).",
		long:        "ratelimit",
		short:       "r",
		valueType:   "int",
	},
	ratelimitSubnetLenIPv4Idx: {
		description: "Ratelimit subnet length for IPv4.",
		long:        "ratelimit-subnet-len-ipv4",
		short:       "",
		valueType:   "int",
	},
	ratelimitSubnetLenIPv6Idx: {
		description: "Ratelimit subnet length for IPv6.",
		long:        "ratelimit-subnet-len-ipv6",
		short:       "",
		valueType:   "int",
	},
	udpBufferSizeIdx: {
		description: "Set the size of the UDP buffer in bytes. A value <= 0 will use the system " +
			"default.",
		long:      "udp-buf-size",
		short:     "",
		valueType: "int",
	},
	maxGoRoutinesIdx: {
		// Fixed the duplicated "not" that previously read "will not not set".
		description: "Set the maximum number of go routines. A zero value will not set a " +
			"maximum.",
		long:      "max-go-routines",
		short:     "",
		valueType: "uint",
	},
	tlsMinVersionIdx: {
		description: "Minimum TLS version, for example 1.0.",
		long:        "tls-min-version",
		short:       "",
		valueType:   "version",
	},
	tlsMaxVersionIdx: {
		description: "Maximum TLS version, for example 1.3.",
		long:        "tls-max-version",
		short:       "",
		valueType:   "version",
	},
	helpIdx: {
		description: "Print this help message and quit.",
		long:        "help",
		short:       "h",
		valueType:   "",
	},
	hostsFileEnabledIdx: {
		description: "If specified, use hosts files for resolving.",
		long:        "hosts-file-enabled",
		short:       "",
		valueType:   "",
	},
	pprofIdx: {
		description: "If present, exposes pprof information on localhost:6060.",
		long:        "pprof",
		short:       "",
		valueType:   "",
	},
	versionIdx: {
		description: "Prints the program version.",
		long:        "version",
		short:       "",
		valueType:   "",
	},
	verboseIdx: {
		description: "Verbose output.",
		long:        "verbose",
		short:       "v",
		valueType:   "",
	},
	insecureIdx: {
		description: "Disable secure TLS certificate validation.",
		long:        "insecure",
		short:       "",
		valueType:   "",
	},
	ipv6DisabledIdx: {
		description: "If specified, all AAAA requests will be replied with NoError RCode and " +
			"empty answer.",
		long:      "ipv6-disabled",
		short:     "",
		valueType: "",
	},
	http3Idx: {
		description: "Enable HTTP/3 support.",
		long:        "http3",
		short:       "",
		valueType:   "",
	},
	cacheOptimisticIdx: {
		description: "If specified, optimistic DNS cache is enabled.",
		long:        "cache-optimistic",
		short:       "",
		valueType:   "",
	},
	cacheIdx: {
		description: "If specified, DNS cache is enabled.",
		long:        "cache",
		short:       "",
		valueType:   "",
	},
	refuseAnyIdx: {
		description: "If specified, refuses ANY requests.",
		long:        "refuse-any",
		short:       "",
		valueType:   "",
	},
	enableEDNSSubnetIdx: {
		description: "Use EDNS Client Subnet extension.",
		long:        "edns",
		short:       "",
		valueType:   "",
	},
	dns64Idx: {
		description: "If specified, dnsproxy will act as a DNS64 server.",
		long:        "dns64",
		short:       "",
		valueType:   "",
	},
	usePrivateRDNSIdx: {
		description: "If specified, use private upstreams for reverse DNS lookups of private " +
			"addresses.",
		long:      "use-private-rdns",
		short:     "",
		valueType: "",
	},
}
// parseCmdLineOptions parses os.Args into conf, which must not be nil.  Every
// configuration field that has a command-line option is registered on a fresh
// flag set.  The current field values act as the flag defaults, so this
// function can be run again on top of values loaded from a config file.  It
// returns the flag-parsing error as is, or an error when positional
// (non-flag) arguments remain.
func parseCmdLineOptions(conf *configuration) (err error) {
	cmdName := os.Args[0]
	args := os.Args[1:]

	flags := flag.NewFlagSet(cmdName, flag.ContinueOnError)

	// Element idx of this slice is the target of commandLineOptions[idx]; both
	// literals are keyed by the same *Idx constants.
	fieldPtrs := []any{
		configPathIdx:             &conf.ConfigPath,
		logOutputIdx:              &conf.LogOutput,
		tlsCertPathIdx:            &conf.TLSCertPath,
		tlsKeyPathIdx:             &conf.TLSKeyPath,
		httpsServerNameIdx:        &conf.HTTPSServerName,
		httpsUserinfoIdx:          &conf.HTTPSUserinfo,
		dnsCryptConfigPathIdx:     &conf.DNSCryptConfigPath,
		ednsAddrIdx:               &conf.EDNSAddr,
		upstreamModeIdx:           &conf.UpstreamMode,
		listenAddrsIdx:            &conf.ListenAddrs,
		listenPortsIdx:            &conf.ListenPorts,
		httpsListenPortsIdx:       &conf.HTTPSListenPorts,
		tlsListenPortsIdx:         &conf.TLSListenPorts,
		quicListenPortsIdx:        &conf.QUICListenPorts,
		dnsCryptListenPortsIdx:    &conf.DNSCryptListenPorts,
		upstreamsIdx:              &conf.Upstreams,
		bootstrapDNSIdx:           &conf.BootstrapDNS,
		fallbacksIdx:              &conf.Fallbacks,
		privateRDNSUpstreamsIdx:   &conf.PrivateRDNSUpstreams,
		dns64PrefixIdx:            &conf.DNS64Prefix,
		privateSubnetsIdx:         &conf.PrivateSubnets,
		bogusNXDomainIdx:          &conf.BogusNXDomain,
		hostsFilesIdx:             &conf.HostsFiles,
		timeoutIdx:                &conf.Timeout,
		cacheMinTTLIdx:            &conf.CacheMinTTL,
		cacheMaxTTLIdx:            &conf.CacheMaxTTL,
		cacheSizeBytesIdx:         &conf.CacheSizeBytes,
		ratelimitIdx:              &conf.Ratelimit,
		ratelimitSubnetLenIPv4Idx: &conf.RatelimitSubnetLenIPv4,
		ratelimitSubnetLenIPv6Idx: &conf.RatelimitSubnetLenIPv6,
		udpBufferSizeIdx:          &conf.UDPBufferSize,
		maxGoRoutinesIdx:          &conf.MaxGoRoutines,
		tlsMinVersionIdx:          &conf.TLSMinVersion,
		tlsMaxVersionIdx:          &conf.TLSMaxVersion,
		helpIdx:                   &conf.help,
		hostsFileEnabledIdx:       &conf.HostsFileEnabled,
		pprofIdx:                  &conf.Pprof,
		versionIdx:                &conf.Version,
		verboseIdx:                &conf.Verbose,
		insecureIdx:               &conf.Insecure,
		ipv6DisabledIdx:           &conf.IPv6Disabled,
		http3Idx:                  &conf.HTTP3,
		cacheOptimisticIdx:        &conf.CacheOptimistic,
		cacheIdx:                  &conf.Cache,
		refuseAnyIdx:              &conf.RefuseAny,
		enableEDNSSubnetIdx:       &conf.EnableEDNSSubnet,
		dns64Idx:                  &conf.DNS64,
		usePrivateRDNSIdx:         &conf.UsePrivateRDNS,
	}
	for idx := range fieldPtrs {
		addOption(flags, fieldPtrs[idx], commandLineOptions[idx])
	}

	flags.Usage = func() { usage(cmdName, os.Stderr) }

	if err = flags.Parse(args); err != nil {
		// The flag package's error already names the offending option.
		return err
	}

	if rest := flags.Args(); len(rest) > 0 {
		return fmt.Errorf("positional arguments are not allowed, please check your command line "+
			"arguments; detected positional arguments: %s", rest)
	}

	return nil
}
// defineFlag registers the option described by o through setFlag, typically a
// method value such as flags.StringVar.  The long name is always registered.
// The short alias is registered too when it is non-empty, and both write to
// fieldPtr.  The field's current value is used as the default.  o must not be
// nil.
func defineFlag[T any](
	fieldPtr *T,
	o *commandLineOption,
	setFlag func(p *T, name string, value T, usage string),
) {
	names := []string{o.long}
	if o.short != "" {
		names = append(names, o.short)
	}

	for _, name := range names {
		setFlag(fieldPtr, name, *fieldPtr, o.description)
	}
}
// defineFlagVar registers value on flags under the long name of o.  If o has
// a short alias, it is registered as well, and both names share the same
// [flag.Value].  o must not be nil.
func defineFlagVar(flags *flag.FlagSet, value flag.Value, o *commandLineOption) {
	names := []string{o.long}
	if o.short != "" {
		names = append(names, o.short)
	}

	for _, name := range names {
		flags.Var(value, name, o.description)
	}
}
// defineTimeutilDurationFlag registers fieldPtr as a text-unmarshalled flag
// under the long name of o, plus its short alias when there is one.  The
// current value of *fieldPtr is the default.  o must not be nil.
func defineTimeutilDurationFlag(
	flags *flag.FlagSet,
	fieldPtr *timeutil.Duration,
	o *commandLineOption,
) {
	names := []string{o.long}
	if o.short != "" {
		names = append(names, o.short)
	}

	for _, name := range names {
		flags.TextVar(fieldPtr, name, *fieldPtr, o.description)
	}
}
// addOption registers the command-line option described by o on flags, with
// fieldPtr as the destination.  It dispatches on the pointer's dynamic type:
// - types the flag package supports natively go through defineFlag;
// - the remaining types are wrapped in a local [flag.Value] adapter.
// It panics on an unsupported pointer type, which is a programming error.
func addOption(flags *flag.FlagSet, fieldPtr any, o *commandLineOption) {
	switch p := fieldPtr.(type) {
	case *timeutil.Duration:
		defineTimeutilDurationFlag(flags, p, o)
	case *[]string:
		defineFlagVar(flags, newStringSliceValue(p), o)
	case *[]int:
		defineFlagVar(flags, newIntSliceValue(p), o)
	case *float32:
		defineFlagVar(flags, (*float32Value)(p), o)
	case *uint32:
		defineFlagVar(flags, (*uint32Value)(p), o)
	case *uint:
		defineFlag(p, o, flags.UintVar)
	case *int:
		defineFlag(p, o, flags.IntVar)
	case *bool:
		defineFlag(p, o, flags.BoolVar)
	case *string:
		defineFlag(p, o, flags.StringVar)
	default:
		panic(fmt.Errorf("unexpected field pointer type %T: %w", p, errors.ErrBadEnumValue))
	}
}
// usage writes a help message to output.  It resembles the one from package
// flag but lists each option's long and short forms together, with its value
// hint.  Options are sorted by long name, and each is followed by its
// description on a separate line.  The text is built in memory and written
// in a single call.
func usage(cmdName string, output io.Writer) {
	sorted := slices.Clone(commandLineOptions)
	slices.SortStableFunc(sorted, func(x, y *commandLineOption) (res int) {
		return strings.Compare(x.long, y.long)
	})

	sb := &strings.Builder{}
	_, _ = fmt.Fprintf(sb, "Usage of %s:\n", cmdName)

	for _, opt := range sorted {
		writeUsageLine(sb, opt)

		// The leading whitespace plus a tab is meant to align the descriptions
		// under common tab stops.
		_, _ = fmt.Fprintf(sb, " \t%s\n", opt.description)
	}

	_, _ = io.WriteString(output, sb.String())
}
// writeUsageLine appends the header line for o to b in one of four shapes.
// The shape depends on whether o has a short alias and whether it has a value
// hint:
//
//	--long
//	--long=hint
//	--long/-s
//	--long=hint/-s hint
func writeUsageLine(b *strings.Builder, o *commandLineOption) {
	hasShort := o.short != ""
	hasHint := o.valueType != ""

	switch {
	case !hasShort && !hasHint:
		_, _ = fmt.Fprintf(b, " --%s\n", o.long)
	case !hasShort:
		_, _ = fmt.Fprintf(b, " --%s=%s\n", o.long, o.valueType)
	case !hasHint:
		_, _ = fmt.Fprintf(b, " --%s/-%s\n", o.long, o.short)
	default:
		_, _ = fmt.Fprintf(b, " --%[1]s=%[3]s/-%[2]s %[3]s\n", o.long, o.short, o.valueType)
	}
}
// processCmdLineOptions decides, from the result of command-line parsing,
// whether dnsproxy should stop right away and with which exit code:
// - a parse error yields the argument-error code;
// - --help prints the usage to stdout and yields success;
// - --version prints the version and yields success.
// Otherwise needExit is false.
func processCmdLineOptions(conf *configuration, parseErr error) (exitCode int, needExit bool) {
	switch {
	case parseErr != nil:
		// For flag syntax errors the flag set has already printed the usage,
		// so nothing else is printed here.
		return osutil.ExitCodeArgumentError, true
	case conf.help:
		usage(os.Args[0], os.Stdout)

		return osutil.ExitCodeSuccess, true
	case conf.Version:
		fmt.Printf("dnsproxy version %s\n", version.Version())

		return osutil.ExitCodeSuccess, true
	default:
		return osutil.ExitCodeSuccess, false
	}
}
07070100000024000081A4000000000000000000000001679A649F0000109E000000000000000000000000000000000000002400000000dnsproxy-0.75.0/internal/cmd/cmd.go// Package cmd is the dnsproxy CLI entry point.
package cmd
import (
"context"
"fmt"
"log/slog"
"net/http"
"net/http/pprof"
"os"
"os/signal"
"syscall"
"time"
"github.com/AdguardTeam/dnsproxy/internal/version"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
)
// Main is the entry point of the dnsproxy CLI.  Its steps are:
// - read the configuration from the command line and, if given, a YAML file;
// - set up logging to stdout or to the configured log file;
// - optionally start the pprof server;
// - run the proxy.
// It terminates the process with a non-zero exit code when any of these
// steps fails.
func Main() {
	conf, exitCode, parseErr := parseConfig()
	if parseErr != nil {
		_, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("parsing options: %w", parseErr))
	}

	// A nil configuration means either a failure or an informational request
	// such as --help; exitCode is already set accordingly.
	if conf == nil {
		os.Exit(exitCode)
	}

	out := os.Stdout
	if path := conf.LogOutput; path != "" {
		// #nosec G302 -- Trust the file path that is given in the
		// configuration.
		f, openErr := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
		if openErr != nil {
			_, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("cannot create a log file: %s", openErr))
			os.Exit(osutil.ExitCodeArgumentError)
		}

		out = f
		defer func() { _ = out.Close() }()
	}

	level := slog.LevelInfo
	if conf.Verbose {
		level = slog.LevelDebug
	}

	logger := slogutil.New(&slogutil.Config{
		Output: out,
		Format: slogutil.FormatDefault,
		Level:  level,
		// TODO(d.kolyshev): Consider making configurable.
		AddTimestamp: true,
	})

	ctx := context.Background()
	if conf.Pprof {
		runPprof(logger)
	}

	if err := runProxy(ctx, logger, conf); err != nil {
		logger.ErrorContext(ctx, "running dnsproxy", slogutil.KeyError, err)

		// os.Exit skips deferred calls, so the log file is closed explicitly.
		//
		// TODO(a.garipov): Consider making logger.Close method.
		if out != os.Stdout {
			_ = out.Close()
		}

		os.Exit(osutil.ExitCodeFailure)
	}
}
// runProxy builds the proxy from conf, starts it, and then blocks until the
// process receives SIGINT or SIGTERM.  After that it shuts the proxy down.
// Errors from the configuration, creation, start, and shutdown steps are
// wrapped with the failing step.  l must not be nil.
//
// TODO(e.burkov): Move into separate dnssvc package.
func runProxy(ctx context.Context, l *slog.Logger, conf *configuration) (err error) {
	l.InfoContext(
		ctx,
		"dnsproxy starting",
		"version", version.Version(),
		"revision", version.Revision(),
		"branch", version.Branch(),
		"commit_time", version.CommitTime(),
	)

	proxyConf, err := createProxyConfig(ctx, l, conf)
	if err != nil {
		return fmt.Errorf("configuring proxy: %w", err)
	}

	dnsProxy, err := proxy.New(proxyConf)
	if err != nil {
		return fmt.Errorf("creating proxy: %w", err)
	}

	if err = dnsProxy.Start(ctx); err != nil {
		return fmt.Errorf("starting dnsproxy: %w", err)
	}

	// Block until the process is asked to terminate.
	//
	// TODO(e.burkov): Use [service.SignalHandler].
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh

	if err = dnsProxy.Shutdown(ctx); err != nil {
		return fmt.Errorf("stopping dnsproxy: %w", err)
	}

	return nil
}
// runPprof serves the net/http/pprof handlers on localhost:6060 from a
// background goroutine and returns immediately.  A failure to serve is only
// logged, not returned.  l must not be nil.
//
// TODO(e.burkov): Use [httputil.RoutePprof].
func runPprof(l *slog.Logger) {
	mux := http.NewServeMux()
	mux.HandleFunc("/debug/pprof/", pprof.Index)
	mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
	mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
	mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
	mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
	mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
	mux.Handle("/debug/pprof/block", pprof.Handler("block"))
	mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
	mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
	mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
	mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))

	go func() {
		// TODO(d.kolyshev): Consider making configurable.
		pprofAddr := "localhost:6060"
		l.Info("starting pprof", "addr", pprofAddr)

		srv := &http.Server{
			Addr:        pprofAddr,
			ReadTimeout: 60 * time.Second,
			Handler:     mux,
		}

		err := srv.ListenAndServe()
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			// slog does not treat the message as a format string, so the
			// message stays constant and the details go into attributes.
			l.Error("pprof failed to listen", "addr", pprofAddr, slogutil.KeyError, err)
		}
	}()
}
07070100000025000081A4000000000000000000000001679A649F0000215A000000000000000000000000000000000000002700000000dnsproxy-0.75.0/internal/cmd/config.gopackage cmd
import (
"fmt"
"os"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/timeutil"
"gopkg.in/yaml.v3"
)
// configuration represents dnsproxy configuration.
//
// parseConfig fills it in this order:
// - the defaults set in parseConfig;
// - the command line;
// - the YAML file at ConfigPath, if one is given;
// - the command line again, so that command-line options take priority.
type configuration struct {
// ConfigPath is the path to the configuration file.
//
// NOTE(review): this field has no yaml tag, unlike the other exported
// fields; confirm whether it is meant to be settable from the YAML file.
ConfigPath string
// LogOutput is the path to the log file.
LogOutput string `yaml:"output"`
// TLSCertPath is the path to the .crt with the certificate chain.
TLSCertPath string `yaml:"tls-crt"`
// TLSKeyPath is the path to the file with the private key.
TLSKeyPath string `yaml:"tls-key"`
// HTTPSServerName sets Server header for the HTTPS server.
HTTPSServerName string `yaml:"https-server-name"`
// HTTPSUserinfo is the sole permitted userinfo for the DoH basic
// authentication. If it is set, all DoH queries are required to have this
// basic authentication information.
HTTPSUserinfo string `yaml:"https-userinfo"`
// DNSCryptConfigPath is the path to the DNSCrypt configuration file.
DNSCryptConfigPath string `yaml:"dnscrypt-config"`
// EDNSAddr is the custom EDNS Client Address to send.
EDNSAddr string `yaml:"edns-addr"`
// UpstreamMode determines the logic through which upstreams will be used.
// If not specified the [proxy.UpstreamModeLoadBalance] is used.
UpstreamMode string `yaml:"upstream-mode"`
// ListenAddrs is the list of server's listen addresses.
ListenAddrs []string `yaml:"listen-addrs"`
// ListenPorts are the ports server listens on.
ListenPorts []int `yaml:"listen-ports"`
// HTTPSListenPorts are the ports server listens on for DNS-over-HTTPS.
HTTPSListenPorts []int `yaml:"https-port"`
// TLSListenPorts are the ports server listens on for DNS-over-TLS.
TLSListenPorts []int `yaml:"tls-port"`
// QUICListenPorts are the ports server listens on for DNS-over-QUIC.
QUICListenPorts []int `yaml:"quic-port"`
// DNSCryptListenPorts are the ports server listens on for DNSCrypt.
DNSCryptListenPorts []int `yaml:"dnscrypt-port"`
// Upstreams is the list of DNS upstream servers.
Upstreams []string `yaml:"upstream"`
// BootstrapDNS is the list of bootstrap DNS upstream servers.
BootstrapDNS []string `yaml:"bootstrap"`
// Fallbacks is the list of fallback DNS upstream servers.
Fallbacks []string `yaml:"fallback"`
// PrivateRDNSUpstreams are upstreams to use for reverse DNS lookups of
// private addresses, including the requests for authority records, such as
// SOA and NS.
PrivateRDNSUpstreams []string `yaml:"private-rdns-upstream"`
// DNS64Prefix defines the DNS64 prefixes that dnsproxy should use when it
// acts as a DNS64 server. If not specified, dnsproxy uses the default
// Well-Known Prefix. This option can be specified multiple times.
DNS64Prefix []string `yaml:"dns64-prefix"`
// PrivateSubnets is the list of private subnets to determine private
// addresses.
PrivateSubnets []string `yaml:"private-subnets"`
// BogusNXDomain transforms responses that contain at least one of the given
// IP addresses into NXDOMAIN.
//
// TODO(a.garipov): Find a way to use [netutil.Prefix]. Currently, package
// go-flags doesn't support text unmarshalers.
//
// NOTE(review): command-line parsing now uses the standard flag package
// (see parseCmdLineOptions), so the go-flags rationale above may be stale.
BogusNXDomain []string `yaml:"bogus-nxdomain"`
// HostsFiles is the list of paths to the hosts files to resolve from.
HostsFiles []string `yaml:"hosts-files"`
// Timeout for outbound DNS queries to remote upstream servers in a
// human-readable form. Default is 10s.
Timeout timeutil.Duration `yaml:"timeout"`
// CacheMinTTL is the minimum TTL value for caching DNS entries, in seconds.
// It overrides the TTL value from the upstream server, if the one is less.
CacheMinTTL uint32 `yaml:"cache-min-ttl"`
// CacheMaxTTL is the maximum TTL value for caching DNS entries, in seconds.
// It overrides the TTL value from the upstream server, if the one is
// greater.
CacheMaxTTL uint32 `yaml:"cache-max-ttl"`
// CacheSizeBytes is the cache size in bytes. Default is 64k.
CacheSizeBytes int `yaml:"cache-size"`
// Ratelimit is the maximum number of requests per second.
Ratelimit int `yaml:"ratelimit"`
// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
// rate limiting requests.
RatelimitSubnetLenIPv4 int `yaml:"ratelimit-subnet-len-ipv4"`
// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
// rate limiting requests.
RatelimitSubnetLenIPv6 int `yaml:"ratelimit-subnet-len-ipv6"`
// UDPBufferSize is the size of the UDP buffer in bytes. A value <= 0 will
// use the system default.
UDPBufferSize int `yaml:"udp-buf-size"`
// MaxGoRoutines is the maximum number of goroutines.
MaxGoRoutines uint `yaml:"max-go-routines"`
// TLSMinVersion is the minimum allowed version of TLS.
//
// TODO(d.kolyshev): Use more suitable type.
TLSMinVersion float32 `yaml:"tls-min-version"`
// TLSMaxVersion is the maximum allowed version of TLS.
//
// TODO(d.kolyshev): Use more suitable type.
TLSMaxVersion float32 `yaml:"tls-max-version"`
// help, if true, prints the command-line option help message and quit with
// a successful exit-code.  It is unexported and has no yaml tag, so it is
// expected to come only from the --help/-h command-line option.
help bool
// HostsFileEnabled controls whether hosts files are used for resolving or
// not.
HostsFileEnabled bool `yaml:"hosts-file-enabled"`
// Pprof defines whether the pprof information needs to be exposed via
// localhost:6060 or not.
Pprof bool `yaml:"pprof"`
// Version, if true, prints the program version, and exits.
//
// NOTE(review): parseConfig checks this flag before it reads the YAML
// file, so setting it only in YAML appears to have no effect — confirm.
Version bool `yaml:"version"`
// Verbose controls the verbosity of the output.
Verbose bool `yaml:"verbose"`
// Insecure disables upstream servers TLS certificate verification.
Insecure bool `yaml:"insecure"`
// IPv6Disabled makes the server to respond with NODATA to all AAAA queries.
IPv6Disabled bool `yaml:"ipv6-disabled"`
// HTTP3 controls whether HTTP/3 is enabled for this instance of dnsproxy.
// It enables HTTP/3 support for both the DoH upstreams and the DoH server.
HTTP3 bool `yaml:"http3"`
// CacheOptimistic, if set to true, enables the optimistic DNS cache. That
// means that cached results will be served even if their cache TTL has
// already expired.
CacheOptimistic bool `yaml:"cache-optimistic"`
// Cache controls whether DNS responses are cached or not.
Cache bool `yaml:"cache"`
// RefuseAny makes the server to refuse requests of type ANY.
RefuseAny bool `yaml:"refuse-any"`
// EnableEDNSSubnet uses EDNS Client Subnet extension.
EnableEDNSSubnet bool `yaml:"edns"`
// DNS64 defines whether DNS64 functionality is enabled or not.
DNS64 bool `yaml:"dns64"`
// UsePrivateRDNS makes the server to use private upstreams for reverse DNS
// lookups of private addresses, including the requests for authority
// records, such as SOA and NS.
UsePrivateRDNS bool `yaml:"use-private-rdns"`
}
// parseConfig builds the configuration from built-in defaults, the command
// line, and, when ConfigPath is set, the YAML file it points to.  The command
// line is applied again after the file so that it takes priority.  A nil conf
// means the caller should exit with exitCode.  That happens on errors and
// also after --help or --version, in which case err is nil.
func parseConfig() (conf *configuration, exitCode int, err error) {
	conf = &configuration{
		HTTPSServerName:        "dnsproxy",
		UpstreamMode:           string(proxy.UpstreamModeLoadBalance),
		CacheSizeBytes:         64 * 1024,
		Timeout:                timeutil.Duration(10 * time.Second),
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 56,
		HostsFileEnabled:       true,
	}

	err = parseCmdLineOptions(conf)
	exitCode, needExit := processCmdLineOptions(conf, err)
	if needExit {
		return nil, exitCode, err
	}

	path := conf.ConfigPath
	if path == "" {
		return conf, exitCode, nil
	}

	// TODO(d.kolyshev): Bootstrap and use slog.
	fmt.Printf("dnsproxy config path: %s\n", path)

	if err = parseConfigFile(conf, path); err != nil {
		return nil, osutil.ExitCodeFailure, fmt.Errorf("parsing config file %s: %w", path, err)
	}

	// Apply the command line on top of the file so that it wins over YAML.
	if err = parseCmdLineOptions(conf); err != nil {
		// The error is already informative, so it is returned unwrapped.
		return nil, osutil.ExitCodeFailure, err
	}

	return conf, exitCode, nil
}
// parseConfigFile reads the YAML file at confPath and decodes it into conf.
// Fields missing from the file keep their current values.  Read errors and
// decode errors are wrapped separately.
func parseConfigFile(conf *configuration, confPath string) (err error) {
	// #nosec G304 -- Trust the file path that is given in the args.
	data, err := os.ReadFile(confPath)
	if err != nil {
		return fmt.Errorf("reading file: %w", err)
	}

	if err = yaml.Unmarshal(data, conf); err != nil {
		return fmt.Errorf("unmarshalling file: %w", err)
	}

	return nil
}
07070100000026000081A4000000000000000000000001679A649F00000E99000000000000000000000000000000000000002500000000dnsproxy-0.75.0/internal/cmd/flag.gopackage cmd
import (
"flag"
"fmt"
"strconv"
"strings"
"github.com/AdguardTeam/golibs/stringutil"
)
// uint32Value adapts a uint32 so that it can be registered as a flag on a
// [flag.FlagSet].
type uint32Value uint32

// type check
var _ flag.Value = (*uint32Value)(nil)

// Set implements the [flag.Value] interface for *uint32Value.  It parses s
// with base-prefix detection, so "0x10" is accepted.  The result of
// strconv.ParseUint is stored even when it reports an error, as the flag
// package's own integer flags also do.
func (v *uint32Value) Set(s string) (err error) {
	var n uint64
	n, err = strconv.ParseUint(s, 0, 32)
	*v = uint32Value(n)

	return err
}

// String implements the [flag.Value] interface for *uint32Value.  It returns
// the value in base 10.
func (v *uint32Value) String() (out string) {
	out = strconv.FormatUint(uint64(*v), 10)

	return out
}
// float32Value adapts a float32 so that it can be registered as a flag on a
// [flag.FlagSet].
type float32Value float32

// type check
var _ flag.Value = (*float32Value)(nil)

// Set implements the [flag.Value] interface for *float32Value.  It parses s
// at 32-bit precision.  The result of strconv.ParseFloat is stored even when
// it reports an error.
func (v *float32Value) Set(s string) (err error) {
	var f float64
	f, err = strconv.ParseFloat(s, 32)
	*v = float32Value(f)

	return err
}

// String implements the [flag.Value] interface for *float32Value.  It returns
// the value in fixed-point notation with three decimal places.
func (v *float32Value) String() (out string) {
	out = strconv.FormatFloat(float64(*v), 'f', 3, 32)

	return out
}
// intSliceValue is a [flag.Value] that collects repeated integer flags into a
// caller-owned slice.
type intSliceValue struct {
	// values points to the slice that receives the parsed integers.
	values *[]int

	// isSet reports whether the flag has been seen already.  The first valid
	// occurrence discards whatever the slice held before, i.e. its default.
	isSet bool
}

// newIntSliceValue returns an *intSliceValue that stores parsed values in *p.
func newIntSliceValue(p *[]int) (out *intSliceValue) {
	out = &intSliceValue{values: p}

	return out
}

// type check
var _ flag.Value = (*intSliceValue)(nil)

// Set implements the [flag.Value] interface for *intSliceValue.  It appends
// the integer parsed from s.  On the first valid call it replaces the
// previous contents instead.  An unparsable s leaves the slice untouched.
func (v *intSliceValue) Set(s string) (err error) {
	n, err := strconv.Atoi(s)
	if err != nil {
		return fmt.Errorf("parsing integer slice arg %q: %w", s, err)
	}

	if !v.isSet {
		// Drop the default before recording the first user-supplied value.
		*v.values, v.isSet = []int{}, true
	}

	*v.values = append(*v.values, n)

	return nil
}

// String implements the [flag.Value] interface for *intSliceValue.  It returns
// the values joined with commas.  A nil receiver or nil slice pointer yields
// an empty string.
func (v *intSliceValue) String() (out string) {
	if v == nil || v.values == nil {
		return ""
	}

	parts := make([]string, 0, len(*v.values))
	for _, n := range *v.values {
		parts = append(parts, strconv.Itoa(n))
	}

	return strings.Join(parts, ",")
}
// stringSliceValue is a [flag.Value] that collects repeated string flags into
// a caller-owned slice.
type stringSliceValue struct {
	// values points to the slice that receives the flag arguments.
	values *[]string

	// isSet reports whether the flag has been seen already.  The first
	// occurrence discards whatever the slice held before, i.e. its default.
	isSet bool
}

// newStringSliceValue returns a *stringSliceValue that stores values in *p.
func newStringSliceValue(p *[]string) (out *stringSliceValue) {
	out = &stringSliceValue{values: p}

	return out
}

// type check
var _ flag.Value = (*stringSliceValue)(nil)

// Set implements the [flag.Value] interface for *stringSliceValue.  It appends
// s.  On the first call it replaces the previous contents instead.  It never
// fails.
func (v *stringSliceValue) Set(s string) (err error) {
	if !v.isSet {
		// Drop the default before recording the first user-supplied value.
		*v.values, v.isSet = []string{}, true
	}

	*v.values = append(*v.values, s)

	return nil
}

// String implements the [flag.Value] interface for *stringSliceValue.  It
// returns the values joined with commas.  A nil receiver or nil slice pointer
// yields an empty string.
func (v *stringSliceValue) String() (out string) {
	if v == nil || v.values == nil {
		return ""
	}

	return strings.Join(*v.values, ",")
}
07070100000027000081A4000000000000000000000001679A649F00003AA8000000000000000000000000000000000000002600000000dnsproxy-0.75.0/internal/cmd/proxy.gopackage cmd
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/netip"
"net/url"
"os"
"strings"
"time"
"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
"github.com/AdguardTeam/dnsproxy/internal/handler"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/ameshkov/dnscrypt/v2"
"gopkg.in/yaml.v3"
)
// TODO(e.burkov): Use a separate type for the YAML configuration file.
// createProxyConfig builds a [proxy.Config] from conf.  Its steps are:
// - read the hosts files;
// - set up the default request handler;
// - copy the plain settings;
// - run every config initializer.
// Errors from the initializers are joined rather than stopping at the first
// one, and the partially filled config is returned alongside them.  l must
// not be nil.
func createProxyConfig(
	ctx context.Context,
	l *slog.Logger,
	conf *configuration,
) (proxyConf *proxy.Config, err error) {
	paths, err := conf.hostsFiles(ctx, l)
	if err != nil {
		// The error is already descriptive, so it is returned unwrapped.
		return nil, err
	}

	hosts, err := handler.ReadHosts(paths)
	if err != nil {
		return nil, fmt.Errorf("reading hosts files: %w", err)
	}

	defaultHandler := handler.NewDefault(&handler.DefaultConfig{
		Logger: l.With(slogutil.KeyPrefix, "default_handler"),
		// TODO(e.burkov): Use the configured message constructor.
		MessageConstructor: dnsmsg.DefaultMessageConstructor{},
		HaltIPv6:           conf.IPv6Disabled,
		HostsFiles:         hosts,
	})

	proxyConf = &proxy.Config{
		Logger:                 l.With(slogutil.KeyPrefix, proxy.LogPrefix),
		RatelimitSubnetLenIPv4: conf.RatelimitSubnetLenIPv4,
		RatelimitSubnetLenIPv6: conf.RatelimitSubnetLenIPv6,
		Ratelimit:              conf.Ratelimit,
		CacheEnabled:           conf.Cache,
		CacheSizeBytes:         conf.CacheSizeBytes,
		CacheMinTTL:            conf.CacheMinTTL,
		CacheMaxTTL:            conf.CacheMaxTTL,
		CacheOptimistic:        conf.CacheOptimistic,
		RefuseAny:              conf.RefuseAny,
		HTTP3:                  conf.HTTP3,
		// TODO(e.burkov): The following CIDRs are aimed to match any address.
		// This is not quite proper approach to be used by default so think
		// about configuring it.
		TrustedProxies: netutil.SliceSubnetSet{
			netip.MustParsePrefix("0.0.0.0/0"),
			netip.MustParsePrefix("::0/0"),
		},
		EnableEDNSClientSubnet: conf.EnableEDNSSubnet,
		UDPBufferSize:          conf.UDPBufferSize,
		HTTPSServerName:        conf.HTTPSServerName,
		MaxGoroutines:          conf.MaxGoRoutines,
		UsePrivateRDNS:         conf.UsePrivateRDNS,
		PrivateSubnets:         netutil.SubnetSetFunc(netutil.IsLocallyServed),
		RequestHandler:         defaultHandler.HandleRequest,
	}

	// The userinfo is "user" or "user:password"; only the first colon splits.
	if userinfo := conf.HTTPSUserinfo; userinfo != "" {
		if user, pass, hasPass := strings.Cut(userinfo, ":"); hasPass {
			proxyConf.Userinfo = url.UserPassword(user, pass)
		} else {
			proxyConf.Userinfo = url.User(user)
		}
	}

	conf.initBogusNXDomain(ctx, l, proxyConf)

	// Composite-literal elements are evaluated left to right, so the
	// initializers run in the listed order.
	errs := []error{
		conf.initUpstreams(ctx, l, proxyConf),
		conf.initEDNS(ctx, l, proxyConf),
		conf.initTLSConfig(proxyConf),
		conf.initDNSCryptConfig(proxyConf),
		conf.initListenAddrs(proxyConf),
		conf.initSubnets(proxyConf),
	}

	return proxyConf, errors.Join(errs...)
}
// isEmpty reports whether uc has no upstreams of any kind: neither general
// ones, nor domain-reserved ones, nor domain-specific ones. uc must not be
// nil.
//
// TODO(e.burkov): Think of a better way to validate the config. Perhaps,
// return an error from [ParseUpstreamsConfig] if no upstreams were initialized.
func isEmpty(uc *proxy.UpstreamConfig) (ok bool) {
	hasAny := len(uc.Upstreams) > 0 ||
		len(uc.DomainReservedUpstreams) > 0 ||
		len(uc.SpecifiedDomainUpstreams) > 0

	return !hasAny
}
// defaultLocalTimeout is the upper bound for the timeout of local operations,
// such as queries to the private reverse DNS upstreams.
const defaultLocalTimeout time.Duration = time.Second
// initUpstreams fills the upstream-related fields of config: the main upstream
// configuration, the private reverse DNS upstreams and the fallback upstreams
// (each only when non-empty), and the upstream mode, which defaults to
// [proxy.UpstreamModeLoadBalance].
//
// TODO(d.kolyshev): Join errors.
func (conf *configuration) initUpstreams(
	ctx context.Context,
	l *slog.Logger,
	config *proxy.Config,
) (err error) {
	httpVersions := upstream.DefaultHTTPVersions
	if conf.HTTP3 {
		httpVersions = []upstream.HTTPVersion{
			upstream.HTTPVersion3,
			upstream.HTTPVersion2,
			upstream.HTTPVersion11,
		}
	}

	timeout := time.Duration(conf.Timeout)

	boot, err := initBootstrap(ctx, l, conf.BootstrapDNS, &upstream.Options{
		Logger:             l,
		HTTPVersions:       httpVersions,
		InsecureSkipVerify: conf.Insecure,
		Timeout:            timeout,
	})
	if err != nil {
		return fmt.Errorf("initializing bootstrap: %w", err)
	}

	upsOpts := &upstream.Options{
		Logger:             l,
		HTTPVersions:       httpVersions,
		InsecureSkipVerify: conf.Insecure,
		Bootstrap:          boot,
		Timeout:            timeout,
	}

	config.UpstreamConfig, err = proxy.ParseUpstreamsConfig(loadServersList(conf.Upstreams), upsOpts)
	if err != nil {
		return fmt.Errorf("parsing upstreams configuration: %w", err)
	}

	// NOTE(review): Unlike upsOpts, these options don't set
	// InsecureSkipVerify; confirm that this is intended.
	privateUpsOpts := &upstream.Options{
		Logger:       l,
		HTTPVersions: httpVersions,
		Bootstrap:    boot,
		Timeout:      min(defaultLocalTimeout, timeout),
	}

	private, err := proxy.ParseUpstreamsConfig(loadServersList(conf.PrivateRDNSUpstreams), privateUpsOpts)
	if err != nil {
		return fmt.Errorf("parsing private rdns upstreams configuration: %w", err)
	}

	if !isEmpty(private) {
		config.PrivateRDNSUpstreamConfig = private
	}

	fallbacks, err := proxy.ParseUpstreamsConfig(loadServersList(conf.Fallbacks), upsOpts)
	if err != nil {
		return fmt.Errorf("parsing fallback upstreams configuration: %w", err)
	}

	if !isEmpty(fallbacks) {
		config.Fallbacks = fallbacks
	}

	if conf.UpstreamMode == "" {
		config.UpstreamMode = proxy.UpstreamModeLoadBalance

		return nil
	}

	err = config.UpstreamMode.UnmarshalText([]byte(conf.UpstreamMode))
	if err != nil {
		return fmt.Errorf("parsing upstream mode: %w", err)
	}

	return nil
}
// initBootstrap initializes the [upstream.Resolver] for bootstrapping upstream
// servers. It returns the default resolver if no bootstraps were specified.
// The returned resolver will also use system hosts files first.
func initBootstrap(
ctx context.Context,
l *slog.Logger,
bootstraps []string,
opts *upstream.Options,
) (r upstream.Resolver, err error) {
var resolvers []upstream.Resolver
for i, b := range bootstraps {
var ur *upstream.UpstreamResolver
ur, err = upstream.NewUpstreamResolver(b, opts)
if err != nil {
return nil, fmt.Errorf("creating bootstrap resolver at index %d: %w", i, err)
}
resolvers = append(resolvers, upstream.NewCachingResolver(ur))
}
switch len(resolvers) {
case 0:
etcHosts, hostsErr := upstream.NewDefaultHostsResolver(osutil.RootDirFS(), l)
if hostsErr != nil {
l.ErrorContext(ctx, "creating default hosts resolver", slogutil.KeyError, hostsErr)
return net.DefaultResolver, nil
}
return upstream.ConsequentResolver{etcHosts, net.DefaultResolver}, nil
case 1:
return resolvers[0], nil
default:
return upstream.ParallelResolver(resolvers), nil
}
}
// initEDNS sets config.EDNSAddr from conf.EDNSAddr. The address is applied
// only when EDNS Client Subnet is enabled; otherwise, a warning is logged and
// the address is ignored.
func (conf *configuration) initEDNS(
	ctx context.Context,
	l *slog.Logger,
	config *proxy.Config,
) (err error) {
	switch {
	case conf.EDNSAddr == "":
		return nil
	case !conf.EnableEDNSSubnet:
		l.WarnContext(ctx, "--edns is required", "--edns-addr", conf.EDNSAddr)

		return nil
	}

	if config.EDNSAddr, err = netutil.ParseIP(conf.EDNSAddr); err != nil {
		return fmt.Errorf("parsing edns-addr: %w", err)
	}

	return nil
}
// initBogusNXDomain appends the subnets parsed from conf.BogusNXDomain to
// config.BogusNXDomain. Entries that fail to parse are logged and skipped.
func (conf *configuration) initBogusNXDomain(ctx context.Context, l *slog.Logger, config *proxy.Config) {
	for i, subnet := range conf.BogusNXDomain {
		p, err := proxynetutil.ParseSubnet(subnet)
		if err != nil {
			// TODO(a.garipov): Consider returning this err as a proper error.
			l.WarnContext(ctx, "parsing bogus nxdomain", "index", i, slogutil.KeyError, err)

			continue
		}

		config.BogusNXDomain = append(config.BogusNXDomain, p)
	}
}
// initTLSConfig inits the TLS config.
func (conf *configuration) initTLSConfig(config *proxy.Config) (err error) {
if conf.TLSCertPath != "" && conf.TLSKeyPath != "" {
var tlsConfig *tls.Config
tlsConfig, err = newTLSConfig(conf)
if err != nil {
return fmt.Errorf("loading TLS config: %w", err)
}
config.TLSConfig = tlsConfig
}
return nil
}
// initDNSCryptConfig inits the DNSCrypt config.
func (conf *configuration) initDNSCryptConfig(config *proxy.Config) (err error) {
if conf.DNSCryptConfigPath == "" {
return
}
b, err := os.ReadFile(conf.DNSCryptConfigPath)
if err != nil {
return fmt.Errorf("reading DNSCrypt config %q: %w", conf.DNSCryptConfigPath, err)
}
rc := &dnscrypt.ResolverConfig{}
err = yaml.Unmarshal(b, rc)
if err != nil {
return fmt.Errorf("unmarshalling DNSCrypt config: %w", err)
}
cert, err := rc.CreateCert()
if err != nil {
return fmt.Errorf("creating DNSCrypt certificate: %w", err)
}
config.DNSCryptResolverCert = cert
config.DNSCryptProviderName = rc.ProviderName
return nil
}
// parseListenAddrs parses addrStrs into IP addresses, preserving their order.
// If addrStrs is empty, it returns a slice containing only the IPv4
// unspecified address "0.0.0.0".
//
// Every invalid address is reported: the errors, each wrapping the underlying
// parsing error, are joined. If err is not nil, addrs is nil.
func parseListenAddrs(addrStrs []string) (addrs []netip.Addr, err error) {
	addrs = make([]netip.Addr, 0, max(len(addrStrs), 1))

	var errs []error
	for i, a := range addrStrs {
		ip, parseErr := netip.ParseAddr(a)
		if parseErr != nil {
			// The parsing error already mentions the invalid input.
			errs = append(errs, fmt.Errorf("parsing listen address at index %d: %w", i, parseErr))

			continue
		}

		addrs = append(addrs, ip)
	}

	if err = errors.Join(errs...); err != nil {
		return nil, err
	}

	if len(addrs) == 0 {
		// If ListenAddrs has not been parsed through config file nor command
		// line we set it to "0.0.0.0".
		//
		// TODO(a.garipov): Consider using localhost.
		addrs = append(addrs, netip.IPv4Unspecified())
	}

	return addrs, nil
}
// initListenAddrs sets up proxy configuration listen IP addresses.
func (conf *configuration) initListenAddrs(config *proxy.Config) (err error) {
addrs, err := parseListenAddrs(conf.ListenAddrs)
if err != nil {
return fmt.Errorf("parsing listen addresses: %w", err)
}
if len(conf.ListenPorts) == 0 {
// If ListenPorts has not been parsed through config file nor command
// line we set it to 53.
conf.ListenPorts = []int{53}
}
for _, port := range conf.ListenPorts {
for _, ip := range addrs {
addrPort := netip.AddrPortFrom(ip, uint16(port))
config.UDPListenAddr = append(config.UDPListenAddr, net.UDPAddrFromAddrPort(addrPort))
config.TCPListenAddr = append(config.TCPListenAddr, net.TCPAddrFromAddrPort(addrPort))
}
}
initTLSListenAddrs(config, conf, addrs)
initDNSCryptListenAddrs(config, conf, addrs)
return nil
}
// initTLSListenAddrs appends TLS, HTTPS, and QUIC listen addresses to
// proxyConf. Each one combines an address from addrs with a port from
// conf.TLSListenPorts, conf.HTTPSListenPorts, or conf.QUICListenPorts,
// respectively. Within each list, the addresses are ordered by IP first and
// then by port. It does nothing if proxyConf.TLSConfig is nil.
func initTLSListenAddrs(proxyConf *proxy.Config, conf *configuration, addrs []netip.Addr) {
	if proxyConf.TLSConfig == nil {
		return
	}

	for _, ip := range addrs {
		for _, port := range conf.TLSListenPorts {
			ap := netip.AddrPortFrom(ip, uint16(port))
			proxyConf.TLSListenAddr = append(proxyConf.TLSListenAddr, net.TCPAddrFromAddrPort(ap))
		}
	}

	for _, ip := range addrs {
		for _, port := range conf.HTTPSListenPorts {
			ap := netip.AddrPortFrom(ip, uint16(port))
			proxyConf.HTTPSListenAddr = append(proxyConf.HTTPSListenAddr, net.TCPAddrFromAddrPort(ap))
		}
	}

	for _, ip := range addrs {
		for _, port := range conf.QUICListenPorts {
			ap := netip.AddrPortFrom(ip, uint16(port))
			proxyConf.QUICListenAddr = append(proxyConf.QUICListenAddr, net.UDPAddrFromAddrPort(ap))
		}
	}
}
// initDNSCryptListenAddrs appends DNSCrypt TCP and UDP listen addresses to
// proxyConf for every combination of a port from conf.DNSCryptListenPorts and
// an address from addrs, ordered by port first. It does nothing unless both
// the DNSCrypt resolver certificate and the provider name are set.
func initDNSCryptListenAddrs(proxyConf *proxy.Config, conf *configuration, addrs []netip.Addr) {
	hasDNSCrypt := proxyConf.DNSCryptResolverCert != nil && proxyConf.DNSCryptProviderName != ""
	if !hasDNSCrypt {
		return
	}

	for _, port := range conf.DNSCryptListenPorts {
		for _, ip := range addrs {
			ap := netip.AddrPortFrom(ip, uint16(port))

			proxyConf.DNSCryptTCPListenAddr = append(
				proxyConf.DNSCryptTCPListenAddr,
				net.TCPAddrFromAddrPort(ap),
			)
			proxyConf.DNSCryptUDPListenAddr = append(
				proxyConf.DNSCryptUDPListenAddr,
				net.UDPAddrFromAddrPort(ap),
			)
		}
	}
}
// initSubnets sets the DNS64 configuration into conf.
//
// TODO(d.kolyshev): Join errors.
func (conf *configuration) initSubnets(proxyConf *proxy.Config) (err error) {
if proxyConf.UseDNS64 = conf.DNS64; proxyConf.UseDNS64 {
for i, p := range conf.DNS64Prefix {
var pref netip.Prefix
pref, err = netip.ParsePrefix(p)
if err != nil {
return fmt.Errorf("parsing dns64 prefix at index %d: %w", i, err)
}
proxyConf.DNS64Prefs = append(proxyConf.DNS64Prefs, pref)
}
}
if !conf.UsePrivateRDNS {
return nil
}
return conf.initPrivateSubnets(proxyConf)
}
// initSubnets sets the private subnets configuration into conf.
func (conf *configuration) initPrivateSubnets(proxyConf *proxy.Config) (err error) {
private := make([]netip.Prefix, 0, len(conf.PrivateSubnets))
for i, p := range conf.PrivateSubnets {
var pref netip.Prefix
pref, err = netip.ParsePrefix(p)
if err != nil {
return fmt.Errorf("parsing private subnet at index %d: %w", i, err)
}
private = append(private, pref)
}
if len(private) > 0 {
proxyConf.PrivateSubnets = netutil.SliceSubnetSet(private)
}
return nil
}
// loadServersList returns the DNS servers described by sources, in order. Each
// source is either a server address or the path to a file containing one
// address per line. In such files, surrounding whitespace is trimmed, and
// empty lines and lines starting with "#" or "!" are ignored. A source that
// can't be read as a file is treated as a server address. It returns nil if
// sources yield no servers.
func loadServersList(sources []string) []string {
	var servers []string
	for _, source := range sources {
		// #nosec G304 -- Trust the file path that is given in the
		// configuration.
		data, err := os.ReadFile(source)
		if err != nil {
			// Ignore errors, just consider it a server address and not a file.
			servers = append(servers, source)

			continue
		}

		for _, line := range strings.Split(string(data), "\n") {
			line = strings.TrimSpace(line)

			// Ignore empty lines and comments in the file.
			if line == "" || strings.HasPrefix(line, "!") || strings.HasPrefix(line, "#") {
				continue
			}

			servers = append(servers, line)
		}
	}

	return servers
}
// hostsFiles returns the paths of the hosts files to resolve from. These are
// the configured files, or the system default ones if none are configured. It
// returns nil if resolving from hosts files is disabled.
func (conf *configuration) hostsFiles(ctx context.Context, l *slog.Logger) (paths []string, err error) {
	if !conf.HostsFileEnabled {
		l.DebugContext(ctx, "hosts files are disabled")

		return nil, nil
	}

	l.DebugContext(ctx, "hosts files are enabled")

	if len(conf.HostsFiles) != 0 {
		return conf.HostsFiles, nil
	}

	defaults, err := proxynetutil.DefaultHostsPaths()
	if err != nil {
		return nil, fmt.Errorf("getting default hosts files: %w", err)
	}

	l.DebugContext(ctx, "hosts files are not specified, using default", "paths", defaults)

	return defaults, nil
}
07070100000028000081A4000000000000000000000001679A649F0000079F000000000000000000000000000000000000002400000000dnsproxy-0.75.0/internal/cmd/tls.gopackage cmd
import (
"crypto/tls"
"fmt"
"os"
)
// NewTLSConfig returns the TLS config that includes a certificate. Use it for
// server TLS configuration or for a client certificate. If caPath is empty,
// system CAs will be used.
func newTLSConfig(conf *configuration) (c *tls.Config, err error) {
// Set default TLS min/max versions
tlsMinVersion := tls.VersionTLS10
tlsMaxVersion := tls.VersionTLS13
switch conf.TLSMinVersion {
case 1.1:
tlsMinVersion = tls.VersionTLS11
case 1.2:
tlsMinVersion = tls.VersionTLS12
case 1.3:
tlsMinVersion = tls.VersionTLS13
}
switch conf.TLSMaxVersion {
case 1.0:
tlsMaxVersion = tls.VersionTLS10
case 1.1:
tlsMaxVersion = tls.VersionTLS11
case 1.2:
tlsMaxVersion = tls.VersionTLS12
}
cert, err := loadX509KeyPair(conf.TLSCertPath, conf.TLSKeyPath)
if err != nil {
return nil, fmt.Errorf("loading TLS cert: %s", err)
}
// #nosec G402 -- TLS MinVersion is configured by user.
return &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: uint16(tlsMinVersion),
MaxVersion: uint16(tlsMaxVersion),
}, nil
}
// loadX509KeyPair reads and parses a public/private key pair from a pair of
// files. The files must contain PEM encoded data. The certificate file may
// contain intermediate certificates following the leaf certificate to form a
// certificate chain. On successful return, Certificate.Leaf will be nil
// because the parsed form of the certificate is not retained.
func loadX509KeyPair(certFile, keyFile string) (crt tls.Certificate, err error) {
// #nosec G304 -- Trust the file path that is given in the configuration.
certPEMBlock, err := os.ReadFile(certFile)
if err != nil {
return tls.Certificate{}, err
}
// #nosec G304 -- Trust the file path that is given in the configuration.
keyPEMBlock, err := os.ReadFile(keyFile)
if err != nil {
return tls.Certificate{}, err
}
return tls.X509KeyPair(certPEMBlock, keyPEMBlock)
}
07070100000029000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002000000000dnsproxy-0.75.0/internal/dnsmsg0707010000002A000081A4000000000000000000000001679A649F00000D61000000000000000000000000000000000000002F00000000dnsproxy-0.75.0/internal/dnsmsg/constructor.go// Package dnsmsg contains common constants, functions, and types for inspecting
// and constructing DNS messages.
package dnsmsg
import (
"strings"
"github.com/miekg/dns"
)
// MessageConstructor builds DNS response messages of particular kinds. Each
// method returns a new message replying to req.
type MessageConstructor interface {
	// NewMsgNXDOMAIN returns a new reply to req with the NXDOMAIN response
	// code.
	NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg)

	// NewMsgSERVFAIL returns a new reply to req with the SERVFAIL response
	// code.
	NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg)

	// NewMsgNOTIMPLEMENTED returns a new reply to req with the
	// NOTIMPLEMENTED response code.
	NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg)

	// NewMsgNODATA returns a new reply to req with the NOERROR response code
	// and no answers.
	//
	// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
	NewMsgNODATA(req *dns.Msg) (resp *dns.Msg)
}
// DefaultMessageConstructor is the stateless default [MessageConstructor]. Its
// zero value is ready to use.
type DefaultMessageConstructor struct{}

// Make sure DefaultMessageConstructor implements MessageConstructor.
var _ MessageConstructor = DefaultMessageConstructor{}
// NewMsgNXDOMAIN implements the [MessageConstructor] interface for
// DefaultMessageConstructor. The reply has the name error (NXDOMAIN) code and
// the recursion available flag set.
func (DefaultMessageConstructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
	resp = reply(req, dns.RcodeNameError)

	return resp
}
// NewMsgSERVFAIL implements the [MessageConstructor] interface for
// DefaultMessageConstructor. The reply has the server failure (SERVFAIL) code
// and the recursion available flag set.
func (DefaultMessageConstructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
	resp = reply(req, dns.RcodeServerFailure)

	return resp
}
// NewMsgNOTIMPLEMENTED implements the [MessageConstructor] interface for
// DefaultMessageConstructor. The reply always carries an EDNS OPT record,
// because a NOTIMPLEMENTED response without EDNS is treated as "EDNS isn't
// supported".
func (DefaultMessageConstructor) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
	// maxUDPPayload is the largest DNS/UDP payload for IPv6 on a 1500-octet
	// Ethernet MTU: 1500 minus 40 (IPv6 header size) minus 8 (UDP header
	// size). Most of the Internet, and especially its inner core, has an MTU
	// of at least 1500 octets.
	//
	// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
	const maxUDPPayload = 1452

	resp = reply(req, dns.RcodeNotImplemented)

	// Advertise EDNS explicitly, with the DNSSEC OK bit unset.
	resp.SetEdns0(maxUDPPayload, false)

	return resp
}
// NewMsgNODATA implements the [MessageConstructor] interface for
// DefaultMessageConstructor. The reply has the NOERROR code and no answers,
// and its authority section contains a single SOA record for the question
// name. That record makes the response negatively cacheable.
func (DefaultMessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
	resp = reply(req, dns.RcodeSuccess)

	zone := req.Question[0].Name

	// The zone is appended to the mailbox unless it starts with a dot, as the
	// root zone "." does, since that would produce a double dot.
	mbox := "hostmaster."
	if !strings.HasPrefix(zone, ".") {
		mbox += zone
	}

	resp.Ns = append(resp.Ns, &dns.SOA{
		// The header is request-specific.
		Hdr: dns.RR_Header{
			Name:   zone,
			Rrtype: dns.TypeSOA,
			Ttl:    10,
			Class:  dns.ClassINET,
		},
		// Copied from AdGuard DNS.
		Ns:     "fake-for-negative-caching.adguard.com.",
		Mbox:   mbox,
		Serial: 100500,
		// Values copied from verisign's nonexistent .com domain. Refresh,
		// Retry, and Expire are only used for transfers between primary and
		// secondary DNS servers. The negative caching TTL is the minimum of the
		// SOA record's TTL and Minttl (RFC 2308, section 5), so the TTL of 10
		// seconds set in the header is what takes effect.
		Refresh: 1800,
		Retry:   60,
		Expire:  604800,
		Minttl:  86400,
	})

	return resp
}
// reply returns a new response to req with the given response code and the
// recursion available flag set.
func reply(req *dns.Msg, code int) (resp *dns.Msg) {
	resp = new(dns.Msg)
	resp.SetRcode(req, code)
	resp.RecursionAvailable = true

	return resp
}
0707010000002B000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002600000000dnsproxy-0.75.0/internal/dnsproxytest0707010000002C000081A4000000000000000000000001679A649F0000006A000000000000000000000000000000000000003600000000dnsproxy-0.75.0/internal/dnsproxytest/dnsproxytest.go// Package dnsproxytest provides a set of test utilities for the dnsproxy
// module.
package dnsproxytest
0707010000002D000081A4000000000000000000000001679A649F00000AA4000000000000000000000000000000000000003300000000dnsproxy-0.75.0/internal/dnsproxytest/interface.gopackage dnsproxytest
import (
"github.com/miekg/dns"
)
// FakeUpstream is an [Upstream] for tests whose methods delegate to the
// corresponding function fields.
//
// TODO(e.burkov): Move this to the golibs some time later.
type FakeUpstream struct {
	// OnAddress is called by [FakeUpstream.Address].
	OnAddress func() (addr string)

	// OnExchange is called by [FakeUpstream.Exchange].
	OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)

	// OnClose is called by [FakeUpstream.Close].
	OnClose func() (err error)
}
// Address implements the [Upstream] interface for *FakeUpstream. It returns
// the result of u.OnAddress.
func (u *FakeUpstream) Address() (addr string) {
	addr = u.OnAddress()

	return addr
}
// Exchange implements the [Upstream] interface for *FakeUpstream. It returns
// the result of u.OnExchange called with req.
func (u *FakeUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	resp, err = u.OnExchange(req)

	return resp, err
}
// Close implements the [Upstream] interface for *FakeUpstream. It returns the
// result of u.OnClose.
func (u *FakeUpstream) Close() (err error) {
	err = u.OnClose()

	return err
}
// TestMessageConstructor is a [dnsmsg.MessageConstructor] for tests whose
// methods delegate to the corresponding function fields.
type TestMessageConstructor struct {
	// OnNewMsgNXDOMAIN is called by [TestMessageConstructor.NewMsgNXDOMAIN].
	OnNewMsgNXDOMAIN func(req *dns.Msg) (resp *dns.Msg)

	// OnNewMsgSERVFAIL is called by [TestMessageConstructor.NewMsgSERVFAIL].
	OnNewMsgSERVFAIL func(req *dns.Msg) (resp *dns.Msg)

	// OnNewMsgNOTIMPLEMENTED is called by
	// [TestMessageConstructor.NewMsgNOTIMPLEMENTED].
	OnNewMsgNOTIMPLEMENTED func(req *dns.Msg) (resp *dns.Msg)

	// OnNewMsgNODATA is called by [TestMessageConstructor.NewMsgNODATA].
	OnNewMsgNODATA func(req *dns.Msg) (resp *dns.Msg)
}
// NewTestMessageConstructor returns a new *TestMessageConstructor in which
// every method panics until its function field is replaced.
func NewTestMessageConstructor() (c *TestMessageConstructor) {
	c = &TestMessageConstructor{}

	c.OnNewMsgNXDOMAIN = func(_ *dns.Msg) (_ *dns.Msg) {
		panic("unexpected call of TestMessageConstructor.NewMsgNXDOMAIN")
	}
	c.OnNewMsgSERVFAIL = func(_ *dns.Msg) (_ *dns.Msg) {
		panic("unexpected call of TestMessageConstructor.NewMsgSERVFAIL")
	}
	c.OnNewMsgNOTIMPLEMENTED = func(_ *dns.Msg) (_ *dns.Msg) {
		panic("unexpected call of TestMessageConstructor.NewMsgNOTIMPLEMENTED")
	}
	c.OnNewMsgNODATA = func(_ *dns.Msg) (_ *dns.Msg) {
		panic("unexpected call of TestMessageConstructor.NewMsgNODATA")
	}

	return c
}
// NewMsgNXDOMAIN implements the [MessageConstructor] interface for
// *TestMessageConstructor. It returns the result of c.OnNewMsgNXDOMAIN.
func (c *TestMessageConstructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
	resp = c.OnNewMsgNXDOMAIN(req)

	return resp
}
// NewMsgSERVFAIL implements the [MessageConstructor] interface for
// *TestMessageConstructor. It returns the result of c.OnNewMsgSERVFAIL.
func (c *TestMessageConstructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
	resp = c.OnNewMsgSERVFAIL(req)

	return resp
}
// NewMsgNOTIMPLEMENTED implements the [MessageConstructor] interface for
// *TestMessageConstructor. It returns the result of
// c.OnNewMsgNOTIMPLEMENTED.
func (c *TestMessageConstructor) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
	resp = c.OnNewMsgNOTIMPLEMENTED(req)

	return resp
}
// NewMsgNODATA implements the [MessageConstructor] interface for
// *TestMessageConstructor. It returns the result of c.OnNewMsgNODATA.
func (c *TestMessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
	resp = c.OnNewMsgNODATA(req)

	return resp
}
0707010000002E000081A4000000000000000000000001679A649F00000162000000000000000000000000000000000000003800000000dnsproxy-0.75.0/internal/dnsproxytest/interface_test.gopackage dnsproxytest_test
import (
"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/dnsproxy/upstream"
)
// Make sure the test doubles implement the interfaces they fake.
var _ upstream.Upstream = (*dnsproxytest.FakeUpstream)(nil)

var _ dnsmsg.MessageConstructor = (*dnsproxytest.TestMessageConstructor)(nil)
0707010000002F000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002100000000dnsproxy-0.75.0/internal/handler07070100000030000081A4000000000000000000000001679A649F00000E60000000000000000000000000000000000000003000000000dnsproxy-0.75.0/internal/handler/constructor.gopackage handler
import (
"net/netip"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/miekg/dns"
)
// messageConstructor extends [proxy.MessageConstructor] with the methods for
// building the responses that the handler produces by itself.
type messageConstructor interface {
	proxy.MessageConstructor

	// NewCompressedResponse returns a new reply to req with the given response
	// code and message compression enabled.
	NewCompressedResponse(req *dns.Msg, code int) (resp *dns.Msg)

	// NewPTRAnswer returns a PTR resource record for fqdn pointing to
	// ptrFQDN. Both arguments must be fully qualified domain names.
	NewPTRAnswer(fqdn, ptrFQDN string) (ans *dns.PTR)

	// NewIPResponse returns a new A or AAAA response to req containing ips.
	// All the addresses must belong to the same family.
	NewIPResponse(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg)
}
// defaultConstructor embeds a [proxy.MessageConstructor] and adds the methods
// that [messageConstructor] requires on top of it.
//
// TODO(e.burkov): This implementation reflects the one from AdGuard Home,
// consider moving it to [golibs].
type defaultConstructor struct {
	proxy.MessageConstructor
}

// Make sure defaultConstructor implements messageConstructor.
var _ messageConstructor = defaultConstructor{}
// NewCompressedResponse implements the [messageConstructor] interface for
// defaultConstructor. The reply has the recursion available flag set and
// compression enabled.
func (defaultConstructor) NewCompressedResponse(req *dns.Msg, code int) (resp *dns.Msg) {
	compressed := reply(req, code)
	compressed.Compress = true

	return compressed
}
// NewPTRAnswer implements the [messageConstructor] interface for
// [defaultConstructor]. ptrFQDN is passed through [dns.Fqdn], so it gets a
// trailing dot in the record even if it lacks one.
func (defaultConstructor) NewPTRAnswer(fqdn, ptrFQDN string) (ans *dns.PTR) {
	ans = &dns.PTR{Hdr: hdr(fqdn, dns.TypePTR)}
	ans.Ptr = dns.Fqdn(ptrFQDN)

	return ans
}
// NewIPResponse implements the [messageConstructor] interface for
// [defaultConstructor]. For an A question, the answer contains all of ips if
// they're all IPv4 addresses, and is empty otherwise. For an AAAA question, it
// contains only the IPv6 ones. Questions of other types get a successful
// response with no answers.
func (c defaultConstructor) NewIPResponse(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
	var ans []dns.RR

	switch qtype := req.Question[0].Qtype; qtype {
	case dns.TypeA:
		ans = genAnswersWithIPv4s(req, ips)
	case dns.TypeAAAA:
		for _, ip := range ips {
			if !ip.Is6() {
				continue
			}

			ans = append(ans, newAnswerAAAA(req, ip))
		}
	default:
		// Go on and return an empty response.
	}

	resp = c.NewCompressedResponse(req, dns.RcodeSuccess)
	resp.Answer = ans

	return resp
}
// defaultResponseTTL is the TTL, in seconds, put into the resource records of
// the responses the handler builds itself.
const defaultResponseTTL = 10
// hdr returns a resource record header of the IN class for name and rrType,
// with the TTL set to defaultResponseTTL.
func hdr(name string, rrType uint16) (h dns.RR_Header) {
	h.Name = name
	h.Rrtype = rrType
	h.Class = dns.ClassINET
	h.Ttl = defaultResponseTTL

	return h
}
// reply returns a new response to req with the given response code and the
// recursion available flag set.
func reply(req *dns.Msg, code int) (resp *dns.Msg) {
	resp = new(dns.Msg).SetRcode(req, code)
	resp.RecursionAvailable = true

	return resp
}
// newAnswerA returns an A record for the question name of req pointing to ip.
// ip is expected to be an IPv4 address.
func newAnswerA(req *dns.Msg, ip netip.Addr) (ans *dns.A) {
	ans = &dns.A{A: ip.AsSlice()}
	ans.Hdr = hdr(req.Question[0].Name, dns.TypeA)

	return ans
}
// newAnswerAAAA returns an AAAA record for the question name of req pointing
// to ip. ip is expected to be an IPv6 address.
func newAnswerAAAA(req *dns.Msg, ip netip.Addr) (ans *dns.AAAA) {
	ans = &dns.AAAA{AAAA: ip.AsSlice()}
	ans.Hdr = hdr(req.Question[0].Name, dns.TypeAAAA)

	return ans
}
// genAnswersWithIPv4s returns one A answer for the question name of req per
// address in ips, in the same order. If any of ips isn't an IPv4 address, it
// returns nil. This includes IPv4-mapped IPv6 addresses, since [netip.Addr.Is4]
// reports false for them. No warning is logged in that case; it's up to the
// caller to handle the nil result. It also returns nil if ips is empty.
func genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) {
	for _, ip := range ips {
		if !ip.Is4() {
			return nil
		}

		ans = append(ans, newAnswerA(req, ip))
	}

	return ans
}
07070100000031000081A4000000000000000000000001679A649F00000787000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/handler/default.gopackage handler
import (
"context"
"log/slog"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/hostsfile"
)
// DefaultConfig contains the settings for creating a [Default] handler.
type DefaultConfig struct {
	// MessageConstructor builds the DNS messages the handler responds with.
	// It must not be nil.
	MessageConstructor proxy.MessageConstructor

	// Logger is used for the handler's logging. It must not be nil.
	Logger *slog.Logger

	// HostsFiles is the storage of the hosts files records to answer from.
	HostsFiles hostsfile.Storage

	// HaltIPv6, if true, stops the processing of AAAA requests and makes the
	// handler reply to them with NODATA.
	HaltIPv6 bool
}
// Default is the default configurable [proxy.RequestHandler]. Use [NewDefault]
// to create it.
type Default struct {
	// messages builds the responses.
	messages messageConstructor

	// hosts is the storage of the hosts files records.
	hosts hostsfile.Storage

	// logger is used for logging.
	logger *slog.Logger

	// isIPv6Halted reflects [DefaultConfig.HaltIPv6].
	isIPv6Halted bool
}
// NewDefault returns a new [Default] handler configured by conf. If
// conf.MessageConstructor lacks the extra methods the handler needs, it's
// wrapped into a constructor providing them. conf must not be nil.
func NewDefault(conf *DefaultConfig) (d *Default) {
	d = &Default{
		hosts:        conf.HostsFiles,
		logger:       conf.Logger,
		isIPv6Halted: conf.HaltIPv6,
	}

	if mc, ok := conf.MessageConstructor.(messageConstructor); ok {
		d.messages = mc
	} else {
		d.messages = defaultConstructor{
			MessageConstructor: conf.MessageConstructor,
		}
	}

	return d
}
// HandleRequest handles the DNS request within proxyCtx. It first tries to
// halt an AAAA request and then to answer the request from the hosts files.
// Only if neither produces a response does it call [proxy.Proxy.Resolve].
func (h *Default) HandleRequest(p *proxy.Proxy, proxyCtx *proxy.DNSContext) (err error) {
	// TODO(e.burkov): Use the [*context.Context] instead of
	// [*proxy.DNSContext] when the interface-based handler is implemented.
	ctx := context.TODO()

	req := proxyCtx.Req
	h.logger.DebugContext(ctx, "handling request", "req", &req.Question[0])

	proxyCtx.Res = h.haltAAAA(ctx, req)
	if proxyCtx.Res != nil {
		return nil
	}

	proxyCtx.Res = h.resolveFromHosts(ctx, req)
	if proxyCtx.Res != nil {
		return nil
	}

	return p.Resolve(proxyCtx)
}
07070100000032000081A4000000000000000000000001679A649F0000130F000000000000000000000000000000000000003A00000000dnsproxy-0.75.0/internal/handler/default_internal_test.gopackage handler
import (
"net"
"net/netip"
"os"
"path"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMain silences the standard logger output before running the tests.
//
// TODO(e.burkov): Remove when [hostsfile.DefaultStorage] stops using [log].
func TestMain(m *testing.M) {
	testutil.DiscardLogOutput(m)

	code := m.Run()
	os.Exit(code)
}
// TODO(e.burkov): Add helpers to initialize [proxy.Proxy] to [dnsproxytest]
// and rewrite the tests.
// defaultTimeout bounds the lifetime of the contexts used in tests.
const defaultTimeout time.Duration = time.Second
func TestDefault_haltAAAA(t *testing.T) {
	t.Parallel()

	reqA := (&dns.Msg{}).SetQuestion("domain.example.", dns.TypeA)
	reqAAAA := (&dns.Msg{}).SetQuestion("domain.example.", dns.TypeAAAA)
	nodataResp := (&dns.Msg{}).SetReply(reqA)

	messages := dnsproxytest.NewTestMessageConstructor()
	messages.OnNewMsgNODATA = func(_ *dns.Msg) (resp *dns.Msg) {
		return nodataResp
	}

	testCases := []struct {
		// wantAAAA is the expected result for the AAAA request; nil means
		// that the request isn't halted.
		wantAAAA *dns.Msg
		name     string
		halt     bool
	}{{
		wantAAAA: nil,
		name:     "disabled",
		halt:     false,
	}, {
		wantAAAA: nodataResp,
		name:     "enabled",
		halt:     true,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			hdlr := NewDefault(&DefaultConfig{
				Logger:             slogutil.NewDiscardLogger(),
				MessageConstructor: messages,
				HaltIPv6:           tc.halt,
			})
			ctx := testutil.ContextWithTimeout(t, defaultTimeout)

			// A requests are never halted.
			assert.Nil(t, hdlr.haltAAAA(ctx, reqA))

			if tc.wantAAAA == nil {
				assert.Nil(t, hdlr.haltAAAA(ctx, reqAAAA))
			} else {
				assert.Equal(t, tc.wantAAAA, hdlr.haltAAAA(ctx, reqAAAA))
			}
		})
	}
}
func TestDefault_resolveFromHosts(t *testing.T) {
	t.Parallel()

	// TODO(e.burkov): Use the one from [dnsproxytest].
	messages := dnsmsg.DefaultMessageConstructor{}

	// Read the same hosts file twice, by a relative and by an absolute path.
	relPath := path.Join("testdata", t.Name(), "hosts")
	absPath, err := filepath.Abs(relPath)
	require.NoError(t, err)

	strg, err := ReadHosts([]string{absPath, relPath})
	require.NoError(t, err)

	hdlr := NewDefault(&DefaultConfig{
		MessageConstructor: messages,
		Logger:             slogutil.NewDiscardLogger(),
		HostsFiles:         strg,
		HaltIPv6:           true,
	})

	const (
		fqdnV4 = "ipv4.domain.example."
		fqdnV6 = "ipv6.domain.example."
	)

	var (
		addrV4 = netip.MustParseAddr("1.2.3.4")
		addrV6 = netip.MustParseAddr("2001:db8::1")

		reversedV4      = errors.Must(netutil.IPToReversedAddr(addrV4.AsSlice()))
		reversedV6      = errors.Must(netutil.IPToReversedAddr(addrV6.AsSlice()))
		unknownReversed = errors.Must(netutil.IPToReversedAddr(net.IP{4, 3, 2, 1}))
	)

	// wantHdr returns the header expected on an answer for name of rrType.
	wantHdr := func(name string, rrType uint16) (h dns.RR_Header) {
		return dns.RR_Header{
			Name:   name,
			Rrtype: rrType,
			Class:  dns.ClassINET,
			Ttl:    10,
		}
	}

	testCases := []struct {
		wantAns dns.RR
		req     *dns.Msg
		name    string
	}{{
		wantAns: &dns.A{Hdr: wantHdr(fqdnV4, dns.TypeA), A: addrV4.AsSlice()},
		req:     (&dns.Msg{}).SetQuestion(fqdnV4, dns.TypeA),
		name:    "success_a",
	}, {
		wantAns: &dns.AAAA{Hdr: wantHdr(fqdnV6, dns.TypeAAAA), AAAA: addrV6.AsSlice()},
		req:     (&dns.Msg{}).SetQuestion(fqdnV6, dns.TypeAAAA),
		name:    "success_aaaa",
	}, {
		wantAns: &dns.PTR{Hdr: wantHdr(reversedV4, dns.TypePTR), Ptr: fqdnV4},
		req:     (&dns.Msg{}).SetQuestion(reversedV4, dns.TypePTR),
		name:    "success_ptr_v4",
	}, {
		wantAns: &dns.PTR{Hdr: wantHdr(reversedV6, dns.TypePTR), Ptr: fqdnV6},
		req:     (&dns.Msg{}).SetQuestion(reversedV6, dns.TypePTR),
		name:    "success_ptr_v6",
	}, {
		wantAns: nil,
		req:     (&dns.Msg{}).SetQuestion("unknown.example", dns.TypeA),
		name:    "not_found_a",
	}, {
		wantAns: nil,
		req:     (&dns.Msg{}).SetQuestion("unknown.example", dns.TypeAAAA),
		name:    "not_found_aaaa",
	}, {
		wantAns: nil,
		req:     (&dns.Msg{}).SetQuestion(unknownReversed, dns.TypePTR),
		name:    "not_found_ptr",
	}, {
		wantAns: nil,
		req:     (&dns.Msg{}).SetQuestion("bad.ptr", dns.TypePTR),
		name:    "bad_ptr",
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ctx := testutil.ContextWithTimeout(t, defaultTimeout)
			resp := hdlr.resolveFromHosts(ctx, tc.req)

			if tc.wantAns != nil {
				require.NotNil(t, resp)
				require.Len(t, resp.Answer, 1)

				assert.Equal(t, tc.wantAns, resp.Answer[0])

				return
			}

			assert.Nil(t, resp)
		})
	}
}
07070100000033000081A4000000000000000000000001679A649F0000006F000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/handler/handler.go// Package handler provides some customizable DNS request handling logic used in
// the proxy.
package handler
07070100000034000081A4000000000000000000000001679A649F00000D50000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/internal/handler/hosts.gopackage handler
import (
"context"
"fmt"
"net/netip"
"os"
"slices"
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// emptyStorage is a [hostsfile.Storage] that holds no records at all.  ReadHosts
// returns it when none of the hosts files provided any records.
//
// TODO(e.burkov): Move to [hostsfile].
type emptyStorage [0]hostsfile.Record

// type check
var _ hostsfile.Storage = emptyStorage{}

// ByAddr implements the [hostsfile.Storage] interface for [emptyStorage].  It
// always returns nil.
func (emptyStorage) ByAddr(netip.Addr) (names []string) {
	return nil
}

// ByName implements the [hostsfile.Storage] interface for [emptyStorage].  It
// always returns nil.
func (emptyStorage) ByName(string) (addrs []netip.Addr) {
	return nil
}
// ReadHosts reads and parses the hosts files located at paths.  Errors from
// individual files don't stop the reading of the rest; they are joined into
// err.  strg is always usable, even when err is not nil.  If no records were
// read at all, strg is an empty storage.
func ReadHosts(paths []string) (strg hostsfile.Storage, err error) {
	// The constructor can only fail when it's given readers, and none are
	// passed here, so its error is intentionally ignored.
	defaultStrg, _ := hostsfile.NewDefaultStorage()

	var errs []error
	for _, path := range paths {
		if readErr := readHostsFile(defaultStrg, path); readErr != nil {
			// The error is informative enough, so don't wrap it.
			errs = append(errs, readErr)
		}
	}

	err = errors.Join(errs...)

	// TODO(e.burkov): Add method for length.
	hasRecords := false
	defaultStrg.RangeAddrs(func(_ string, _ []netip.Addr) (cont bool) {
		hasRecords = true

		// A single record is enough, so stop iterating.
		return false
	})

	if !hasRecords {
		return emptyStorage{}, err
	}

	return defaultStrg, err
}
// readHostsFile opens the hosts file at path and parses its records into strg.
// The error of closing the file is joined with the returned error.
func readHostsFile(strg *hostsfile.DefaultStorage, path string) (err error) {
	// #nosec G304 -- Trust the file path from the configuration file.
	file, err := os.Open(path)
	if err != nil {
		// The error returned by os.Open already contains the path.
		return err
	}
	defer func() { err = errors.WithDeferred(err, file.Close()) }()

	if err = hostsfile.Parse(strg, file, nil); err != nil {
		return fmt.Errorf("parsing hosts file %q: %w", path, err)
	}

	return nil
}
// resolveFromHosts tries to answer req using the hosts storage.  It handles A,
// AAAA, and PTR questions and returns nil for any other type, for a malformed
// PTR name, and when the storage has no matching records.  req must not be nil
// and must contain at least one question.
func (h *Default) resolveFromHosts(ctx context.Context, req *dns.Msg) (resp *dns.Msg) {
	q := req.Question[0]
	host := strings.TrimSuffix(q.Name, ".")

	var (
		addrs []netip.Addr
		ptrs  []string
	)

	switch q.Qtype {
	case dns.TypeA, dns.TypeAAAA:
		// Clone the result, since slices.DeleteFunc modifies its argument in
		// place.
		addrs = slices.Clone(h.hosts.ByName(host))
		if q.Qtype == dns.TypeA {
			addrs = slices.DeleteFunc(addrs, netip.Addr.Is6)
		} else {
			addrs = slices.DeleteFunc(addrs, netip.Addr.Is4)
		}
	case dns.TypePTR:
		addr, err := netutil.IPFromReversedAddr(host)
		if err != nil {
			h.logger.DebugContext(ctx, "failed parsing ptr", slogutil.KeyError, err)

			return nil
		}

		ptrs = h.hosts.ByAddr(addr)
	default:
		return nil
	}

	if len(addrs) > 0 {
		return h.messages.NewIPResponse(req, addrs)
	}

	if len(ptrs) > 0 {
		resp = h.messages.NewCompressedResponse(req, dns.RcodeSuccess)
		for _, ptr := range ptrs {
			// Answer with the original, fully-qualified question name.
			resp.Answer = append(resp.Answer, h.messages.NewPTRAnswer(q.Name, dns.Fqdn(ptr)))
		}

		return resp
	}

	h.logger.DebugContext(ctx, "no hosts records found", "name", host, "qtype", q.Qtype)

	return nil
}
07070100000035000081A4000000000000000000000001679A649F000001E2000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/internal/handler/ipv6halt.gopackage handler
import (
"context"
"github.com/miekg/dns"
)
// haltAAAA stops the processing of AAAA requests when IPv6 is disabled by
// replying with a NODATA response.  For every other request, and whenever IPv6
// isn't disabled, it returns nil.  req must not be nil and, when IPv6 is
// disabled, must contain at least one question.
func (h *Default) haltAAAA(ctx context.Context, req *dns.Msg) (resp *dns.Msg) {
	if !h.isIPv6Halted {
		return nil
	}

	q := req.Question[0]
	if q.Qtype != dns.TypeAAAA {
		return nil
	}

	h.logger.DebugContext(
		ctx,
		"ipv6 is disabled; replying with empty response",
		"req", q.Name,
	)

	return h.messages.NewMsgNODATA(req)
}
07070100000036000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/internal/handler/testdata07070100000037000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000004700000000dnsproxy-0.75.0/internal/handler/testdata/TestDefault_resolveFromHosts07070100000038000081A4000000000000000000000001679A649F0000004A000000000000000000000000000000000000004D00000000dnsproxy-0.75.0/internal/handler/testdata/TestDefault_resolveFromHosts/hosts1.2.3.4 ipv4.domain.example
2001:db8::1 ipv6.domain.example
# comment
07070100000039000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002100000000dnsproxy-0.75.0/internal/netutil0707010000003A000081A4000000000000000000000001679A649F000002AC000000000000000000000000000000000000003100000000dnsproxy-0.75.0/internal/netutil/listenconfig.gopackage netutil
import (
"log/slog"
"net"
)
// ListenConfig returns the default [net.ListenConfig] used by the plain-DNS
// servers in this module. l must not be nil.
//
// TODO(a.garipov): Add tests.
//
// TODO(a.garipov): Add an option to not set SO_REUSEPORT on Unix to prevent
// issues with OpenWrt.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/5872.
//
// TODO(a.garipov): DRY with AdGuard DNS when we can.
func ListenConfig(l *slog.Logger) (lc *net.ListenConfig) {
return &net.ListenConfig{
Control: listenControl{logger: l}.defaultListenControl,
}
}
// listenControl is a wrapper struct with logger.
type listenControl struct {
logger *slog.Logger
}
0707010000003B000081A4000000000000000000000001679A649F0000048A000000000000000000000000000000000000003600000000dnsproxy-0.75.0/internal/netutil/listenconfig_unix.go//go:build unix
package netutil
import (
"fmt"
"syscall"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"golang.org/x/sys/unix"
)
// defaultListenControl is a [net.ListenConfig.Control] function that enables
// SO_REUSEADDR and SO_REUSEPORT on every socket the DNS servers of this module
// use.  If SO_REUSEPORT isn't supported (ENOPROTOOPT), a warning is logged
// instead of returning an error.
func (lc listenControl) defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
	var sockErr error
	ctrlErr := c.Control(func(fd uintptr) {
		sock := int(fd)

		sockErr = unix.SetsockoptInt(sock, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
		if sockErr != nil {
			sockErr = fmt.Errorf("setting SO_REUSEADDR: %w", sockErr)

			return
		}

		sockErr = unix.SetsockoptInt(sock, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
		switch {
		case sockErr == nil:
			// Both options are set.
		case errors.Is(sockErr, unix.ENOPROTOOPT):
			// Some Linux OSs, including some varieties of OpenWrt, don't seem
			// to support SO_REUSEPORT, so only warn about it.
			lc.logger.Warn("SO_REUSEPORT not supported", slogutil.KeyError, sockErr)
			sockErr = nil
		default:
			sockErr = fmt.Errorf("setting SO_REUSEPORT: %w", sockErr)
		}
	})

	return errors.WithDeferred(sockErr, ctrlErr)
}
0707010000003C000081A4000000000000000000000001679A649F000000F4000000000000000000000000000000000000003900000000dnsproxy-0.75.0/internal/netutil/listenconfig_windows.go//go:build windows
package netutil
import "syscall"
// defaultListenControl is a no-op on Windows, because Windows doesn't support
// SO_REUSEPORT.  It always returns nil.
func (listenControl) defaultListenControl(string, string, syscall.RawConn) (err error) {
	return nil
}
0707010000003D000081A4000000000000000000000001679A649F00000302000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/netutil/netutil.go// Package netutil contains network-related utilities common among dnsproxy
// packages.
//
// TODO(a.garipov): Move improved versions of these into netutil in module
// golibs.
package netutil
import (
"net/netip"
"strings"
)
// ParseSubnet parses s as a CIDR prefix if it contains a slash, or otherwise as
// a single IP address, which it turns into a prefix covering just that address.
// On error, p is the zero prefix.
//
// TODO(e.burkov): Replace usages with [netutil.Prefix].
func ParseSubnet(s string) (p netip.Prefix, err error) {
	if !strings.Contains(s, "/") {
		ip, parseErr := netip.ParseAddr(s)
		if parseErr != nil {
			return netip.Prefix{}, parseErr
		}

		return netip.PrefixFrom(ip, ip.BitLen()), nil
	}

	p, err = netip.ParsePrefix(s)
	if err != nil {
		return netip.Prefix{}, err
	}

	return p, nil
}
0707010000003E000081A4000000000000000000000001679A649F00000128000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/internal/netutil/paths.gopackage netutil
// DefaultHostsPaths returns the absolute paths to the system hosts files of the
// current platform.
//
// TODO(s.chzhen): Since [fs.FS] is no longer needed, update the
// [hostsfile.DefaultHostsPaths] from golibs.
func DefaultHostsPaths() (paths []string, err error) {
	paths, err = defaultHostsPaths()

	return paths, err
}
0707010000003F000081A4000000000000000000000001679A649F000001C4000000000000000000000000000000000000002F00000000dnsproxy-0.75.0/internal/netutil/paths_unix.go//go:build unix
package netutil
import "github.com/AdguardTeam/golibs/hostsfile"
// defaultHostsPaths returns the default hosts file paths on Unix.  It turns the
// root-relative paths from [hostsfile.DefaultHostsPaths] into absolute ones.
// The returned error is always nil.
func defaultHostsPaths() (paths []string, err error) {
	relPaths, err := hostsfile.DefaultHostsPaths()
	if err != nil {
		// Should not happen, since hostsfile.DefaultHostsPaths always returns
		// a nil error.
		panic(err)
	}

	paths = make([]string, len(relPaths))
	for i, rel := range relPaths {
		paths[i] = "/" + rel
	}

	return paths, nil
}
07070100000040000081A4000000000000000000000001679A649F000001B1000000000000000000000000000000000000003200000000dnsproxy-0.75.0/internal/netutil/paths_windows.go//go:build windows
package netutil
import (
"fmt"
"path"
"golang.org/x/sys/windows"
)
// defaultHostsPaths returns the default hosts file path on Windows, which is
// located under the system directory.
//
// NOTE(review): path.Join always uses forward slashes, so the result mixes
// separators; filepath.Join would be more idiomatic here.
func defaultHostsPaths() (paths []string, err error) {
	sysDir, err := windows.GetSystemDirectory()
	if err != nil {
		return []string{}, fmt.Errorf("getting system directory: %w", err)
	}

	return []string{path.Join(sysDir, "drivers", "etc", "hosts")}, nil
}
07070100000041000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/internal/netutil/testdata07070100000042000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000003400000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts07070100000043000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000003D00000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts/bad_file07070100000044000081A4000000000000000000000001679A649F000000B6000000000000000000000000000000000000004300000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts/bad_file/hosts# comment about the following empty line
# comment about the above empty line
1.2.3.256 a.b # invalid address
1.2.3.4 a.123 # invalid top-level domain
1.2.3.4 .a.b # empty domain
07070100000045000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000003E00000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts/good_file07070100000046000081A4000000000000000000000001679A649F0000033A000000000000000000000000000000000000004400000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts/good_file/hosts# IPv4
# 1st host.
0.0.0.1 Host.One
# 2nd host.
0.0.0.2 Host.Two
# 1st host full duplicate.
0.0.0.1 host.one
# 2nd host duplicate with new name.
0.0.0.2 host.two Host.New
# 1st host with foreign name.
0.0.0.1 host.new
# 2nd host new name.
0.0.0.2 Again.Host.Two
# Mapped
# 1st host.
::ffff:0.0.0.1 Host.One
# 2nd host.
::ffff:0.0.0.2 Host.Two
# 1st host full duplicate.
::ffff:0.0.0.1 host.one
# 2nd host duplicate with new name.
::ffff:0.0.0.2 host.two Host.New
# 1st host with foreign name.
::ffff:0.0.0.1 host.new
# 2nd host new name.
::ffff:0.0.0.2 Again.Host.Two
# IPv6
# 1st host.
::1 Host.One
# 2nd host.
::2 Host.Two
# 1st host full duplicate.
::1 host.one
# 2nd host duplicate with new name.
::2 host.two Host.New
# 1st host with foreign name.
::1 host.new
# 2nd host new name.
::2 Again.Host.Two
07070100000047000081A4000000000000000000000001679A649F0000042A000000000000000000000000000000000000002800000000dnsproxy-0.75.0/internal/netutil/udp.gopackage netutil
import (
"net"
"net/netip"
)
// UDPGetOOBSize returns the size of the buffer needed to hold the OOB data that
// UDPRead receives.
func UDPGetOOBSize() (oobSize int) {
	oobSize = udpGetOOBSize()

	return oobSize
}

// UDPSetOptions configures the UDP socket c so that it receives the OOB data
// that UDPRead relies on.
func UDPSetOptions(c *net.UDPConn) (err error) {
	err = udpSetOptions(c)

	return err
}

// UDPRead reads a single message from conn into buf, using an OOB buffer of
// udpOOBSize bytes.  It returns the number of bytes copied into buf, the local
// address the message was sent to (the zero address where the platform doesn't
// report it), and the address of the sender.
//
// TODO(s.chzhen): Consider using netip.Addr.
func UDPRead(
	conn *net.UDPConn,
	buf []byte,
	udpOOBSize int,
) (n int, localIP netip.Addr, remoteAddr *net.UDPAddr, err error) {
	n, localIP, remoteAddr, err = udpRead(conn, buf, udpOOBSize)

	return n, localIP, remoteAddr, err
}

// UDPWrite sends data to remoteAddr over conn, using localIP as the source
// address where the platform supports it.  It returns the number of bytes
// written.
//
// TODO(s.chzhen): Consider using netip.Addr.
func UDPWrite(
	data []byte,
	conn *net.UDPConn,
	remoteAddr *net.UDPAddr,
	localIP netip.Addr,
) (n int, err error) {
	n, err = udpWrite(data, conn, remoteAddr, localIP)

	return n, err
}
07070100000048000081A4000000000000000000000001679A649F00000856000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/internal/netutil/udp_unix.go//go:build unix
package netutil
import (
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// These are the socket option flags used to configure an IPv4 or IPv6 UDP
// connection so that it receives the appropriate OOB data.  For both versions
// the flags are:
//
//   - FlagDst, to learn the destination (local) address of a packet;
//   - FlagInterface, to learn the interface a packet arrived on.
const (
	// ipv4Flags are the control flags for IPv4 connections.
	ipv4Flags ipv4.ControlFlags = ipv4.FlagDst | ipv4.FlagInterface
	// ipv6Flags are the control flags for IPv6 connections.
	ipv6Flags ipv6.ControlFlags = ipv6.FlagDst | ipv6.FlagInterface
)
// udpGetOOBSize returns a buffer size large enough to hold a control message
// for either ipv4Flags or ipv6Flags, whichever of the two is larger.
func udpGetOOBSize() (oobSize int) {
	v4Len := len(ipv4.NewControlMessage(ipv4Flags))
	v6Len := len(ipv6.NewControlMessage(ipv6Flags))

	return max(v4Len, v6Len)
}
// udpSetOptions enables receiving the control messages selected by ipv6Flags
// and ipv4Flags on c.  It only fails when setting both the IPv6 and the IPv4
// options fails.  Presumably this is because a socket of one address family may
// reject the options of the other.  Both underlying errors are wrapped with %w,
// so callers can inspect them with errors.Is and errors.As.
func udpSetOptions(c *net.UDPConn) (err error) {
	err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6Flags, true)
	err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4Flags, true)
	if err6 != nil && err4 != nil {
		return fmt.Errorf("failed to call SetControlMessage: ipv4: %w; ipv6: %w", err4, err6)
	}

	return nil
}
// udpGetDstFromOOB extracts the destination address of a packet from its OOB
// data.  It tries the IPv6 control message first, then the IPv4 one.  It
// returns the zero address and a nil error if neither contains a destination.
func udpGetDstFromOOB(oob []byte) (dst netip.Addr, err error) {
	v6Msg := &ipv6.ControlMessage{}
	if parseErr := v6Msg.Parse(oob); parseErr == nil && v6Msg.Dst != nil {
		// Linux maps IPv4 addresses to IPv6 ones by default, so an IPv6
		// control message may carry an IPv4 destination; unmap it.
		return netutil.IPToAddrNoMapped(v6Msg.Dst)
	}

	v4Msg := &ipv4.ControlMessage{}
	if parseErr := v4Msg.Parse(oob); parseErr == nil && v4Msg.Dst != nil {
		return netutil.IPToAddr(v4Msg.Dst, netutil.AddrFamilyIPv4)
	}

	return netip.Addr{}, nil
}
// udpRead reads one message from c into buf along with up to udpOOBSize bytes
// of OOB data.  It returns the message length, the destination address taken
// from the OOB data, and the sender's address.  On error, n is -1.
func udpRead(
	c *net.UDPConn,
	buf []byte,
	udpOOBSize int,
) (n int, localIP netip.Addr, remoteAddr *net.UDPAddr, err error) {
	oob := make([]byte, udpOOBSize)

	n, oobLen, _, from, readErr := c.ReadMsgUDP(buf, oob)
	if readErr != nil {
		return -1, netip.Addr{}, nil, readErr
	}

	dst, dstErr := udpGetDstFromOOB(oob[:oobLen])
	if dstErr != nil {
		return -1, netip.Addr{}, nil, dstErr
	}

	return n, dst, from, nil
}
// udpWrite sends data to remoteAddr over conn.  It passes OOB data built from
// localIP so that the reply presumably leaves from the address the request
// arrived at.
func udpWrite(
	data []byte,
	conn *net.UDPConn,
	remoteAddr *net.UDPAddr,
	localIP netip.Addr,
) (n int, err error) {
	oob := udpMakeOOBWithSrc(localIP)
	written, _, writeErr := conn.WriteMsgUDP(data, oob, remoteAddr)

	return written, writeErr
}
07070100000049000081A4000000000000000000000001679A649F00000229000000000000000000000000000000000000003000000000dnsproxy-0.75.0/internal/netutil/udp_windows.go//go:build windows
package netutil
import (
"net"
"net/netip"
)
// udpGetOOBSize always returns 0 on Windows, since udpRead doesn't read any OOB
// data there.
func udpGetOOBSize() (oobSize int) {
	return 0
}

// udpSetOptions is a no-op on Windows and always returns nil.
func udpSetOptions(_ *net.UDPConn) (err error) {
	return nil
}
// udpRead reads one message from c into buf and returns its length and the
// sender's address.  The OOB size is ignored, and the local address is always
// the zero address on Windows.  ReadFromUDP is used instead of ReadFrom, since
// it returns a typed *net.UDPAddr and needs no type assertion.
func udpRead(c *net.UDPConn, buf []byte, _ int) (int, netip.Addr, *net.UDPAddr, error) {
	n, udpAddr, err := c.ReadFromUDP(buf)

	return n, netip.Addr{}, udpAddr, err
}
// udpWrite sends data to remoteAddr over conn.  Setting the source address is
// not supported on Windows, so the local address is ignored.
func udpWrite(data []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, _ netip.Addr) (n int, err error) {
	n, err = conn.WriteTo(data, remoteAddr)

	return n, err
}
0707010000004A000081A4000000000000000000000001679A649F0000026F000000000000000000000000000000000000003200000000dnsproxy-0.75.0/internal/netutil/udpoob_darwin.go//go:build darwin
package netutil
import (
"net/netip"
"golang.org/x/net/ipv6"
)
// udpMakeOOBWithSrc returns OOB data that sets the source address of an
// outgoing packet to ip.  On darwin this is only done for non-IPv4 addresses;
// for IPv4 ones an empty slice is returned.
func udpMakeOOBWithSrc(ip netip.Addr) (b []byte) {
	if !ip.Is4() {
		cm := &ipv6.ControlMessage{Src: ip.AsSlice()}

		return cm.Marshal()
	}

	// Do not set the IPv4 source address via OOB, because it can cause the
	// address to become unspecified on darwin.
	//
	// See https://github.com/AdguardTeam/AdGuardHome/issues/2807.
	//
	// TODO(e.burkov): Develop a workaround to make it write OOB only when
	// listening on an unspecified address.
	return []byte{}
}
0707010000004B000081A4000000000000000000000001679A649F00000186000000000000000000000000000000000000003200000000dnsproxy-0.75.0/internal/netutil/udpoob_others.go//go:build !darwin
package netutil
import (
"net/netip"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// udpMakeOOBWithSrc returns OOB data that sets the source address of an
// outgoing packet to ip.  It uses an IPv4 control message for IPv4 addresses
// and an IPv6 one for everything else.
func udpMakeOOBWithSrc(ip netip.Addr) (b []byte) {
	src := ip.AsSlice()
	if ip.Is4() {
		return (&ipv4.ControlMessage{Src: src}).Marshal()
	}

	return (&ipv6.ControlMessage{Src: src}).Marshal()
}
0707010000004C000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/internal/tools0707010000004D000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002100000000dnsproxy-0.75.0/internal/version0707010000004E000081A4000000000000000000000001679A649F00000351000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/version/version.go// Package version contains dnsproxy version information.
package version
// Build information, set by the linker.  Go can't set constants at link time
// and has no immutable variables, so these are unexported and only exposed
// through the getters below.
var (
	branch     string
	committime string
	revision   string
	version    string
)

// Branch returns the Git branch the binary was built from, as set by the
// linker.  It's empty if the linker didn't set it.
func Branch() (b string) {
	b = branch

	return b
}

// CommitTime returns the time of the built commit as a string, as set by the
// linker.  It's empty if the linker didn't set it.
func CommitTime() (t string) {
	t = committime

	return t
}

// Revision returns the Git revision the binary was built from, as set by the
// linker.  It's empty if the linker didn't set it.
func Revision() (r string) {
	r = revision

	return r
}

// Version returns the build version as a string, as set by the linker.  It's
// empty if the linker didn't set it.
func Version() (v string) {
	v = version

	return v
}
0707010000004F000081A4000000000000000000000001679A649F00000066000000000000000000000000000000000000001800000000dnsproxy-0.75.0/main.gopackage main
import (
"github.com/AdguardTeam/dnsproxy/internal/cmd"
)
// main is the entry point of the dnsproxy binary.  All the work is delegated to
// cmd.Main.
func main() {
	cmd.Main()
}
07070100000050000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001600000000dnsproxy-0.75.0/proxy07070100000051000081A4000000000000000000000001679A649F00000A46000000000000000000000000000000000000002700000000dnsproxy-0.75.0/proxy/beforerequest.gopackage proxy
import (
"fmt"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
)
// BeforeRequestError is returned from [BeforeRequestHandler.HandleBefore] to
// make [Proxy] reply to the client with Response instead of processing the
// request.
type BeforeRequestError struct {
	// Err is the reason for the response.  It must not be nil.
	Err error

	// Response is the message sent to the client.  It must be a valid response
	// message.
	Response *dns.Msg
}

// type check
var _ error = (*BeforeRequestError)(nil)

// Error implements the [error] interface for *BeforeRequestError.
func (e *BeforeRequestError) Error() (msg string) {
	rcode := dns.RcodeToString[e.Response.Rcode]

	return fmt.Sprintf("%s; respond with %s", e.Err, rcode)
}

// type check
var _ errors.Wrapper = (*BeforeRequestError)(nil)

// Unwrap implements the [errors.Wrapper] interface for *BeforeRequestError.
// It returns e.Err.
func (e *BeforeRequestError) Unwrap() (unwrapped error) {
	unwrapped = e.Err

	return unwrapped
}
// BeforeRequestHandler is an object that can handle the request before it's
// processed by [Proxy].
type BeforeRequestHandler interface {
	// HandleBefore is called before the processing of each DNS request starts.
	// The passed [DNSContext] contains the Req, Addr, and IsLocalClient fields
	// set accordingly.
	//
	// If err is nil, the request is processed further.  If err is a
	// [BeforeRequestError], its response message is sent to the client.  Any
	// other non-nil error makes [Proxy] drop the request without responding.
	// [Proxy] assumes that the handler itself doesn't set the [DNSContext.Res]
	// field.
	HandleBefore(p *Proxy, dctx *DNSContext) (err error)
}
// noopRequestHandler is a [BeforeRequestHandler] that lets every request
// through.
type noopRequestHandler struct{}

// type check
var _ BeforeRequestHandler = noopRequestHandler{}

// HandleBefore implements the [BeforeRequestHandler] interface for
// noopRequestHandler.  It always returns nil.
func (noopRequestHandler) HandleBefore(*Proxy, *DNSContext) (err error) {
	return nil
}
// handleBefore runs the configured [BeforeRequestHandler] for d.  It reports
// whether the request should be processed further, which is only the case when
// the handler returns nil.  If the handler returns a [BeforeRequestError], its
// response is sent to the client.  Otherwise, the request is dropped silently.
func (p *Proxy) handleBefore(d *DNSContext) (cont bool) {
	err := p.beforeRequestHandler.HandleBefore(p, d)
	if err == nil {
		return true
	}

	p.logger.Debug("handling before request", slogutil.KeyError, err)

	var befReqErr *BeforeRequestError
	if !errors.As(err, &befReqErr) {
		return false
	}

	d.Res = befReqErr.Response
	p.logDNSMessage(d.Res)
	p.respond(d)

	return false
}
07070100000052000081A4000000000000000000000001679A649F00000D14000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/proxy/beforerequest_test.gopackage proxy
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testBeforeRequestHandler is a [BeforeRequestHandler] for tests that delegates
// every call to onHandleBefore.
type testBeforeRequestHandler struct {
	// onHandleBefore is called by HandleBefore.  It must not be nil.
	onHandleBefore func(p *Proxy, dctx *DNSContext) (err error)
}

// type check
var _ BeforeRequestHandler = (*testBeforeRequestHandler)(nil)

// HandleBefore implements the [BeforeRequestHandler] interface for
// *testBeforeRequestHandler.
func (h *testBeforeRequestHandler) HandleBefore(p *Proxy, dctx *DNSContext) (err error) {
	err = h.onHandleBefore(p, dctx)

	return err
}
// TestProxy_HandleDNSRequest_beforeRequestHandler checks the three outcomes of a
// [BeforeRequestHandler]: letting a request through, dropping it, and replying
// with a [BeforeRequestError] response.  The request ID selects the outcome.
func TestProxy_HandleDNSRequest_beforeRequestHandler(t *testing.T) {
	t.Parallel()
	// Request IDs that the handler below switches on.
	const (
		allowedID = iota
		droppedID
		errorID
	)
	allowedRequest := (&dns.Msg{}).SetQuestion("allowed.", dns.TypeA)
	allowedRequest.Id = allowedID
	allowedResponse := (&dns.Msg{}).SetReply(allowedRequest)
	droppedRequest := (&dns.Msg{}).SetQuestion("dropped.", dns.TypeA)
	droppedRequest.Id = droppedID
	errorRequest := (&dns.Msg{}).SetQuestion("error.", dns.TypeA)
	errorRequest.Id = errorID
	errorResponse := (&dns.Msg{}).SetReply(errorRequest)
	p := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: &UpstreamConfig{
			// The upstream always answers with a copy of allowedResponse, so
			// only requests let through by the handler get that answer.
			Upstreams: []upstream.Upstream{&fakeUpstream{
				onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) {
					return allowedResponse.Copy(), nil
				},
				onAddress: func() (addr string) { return "general" },
				onClose:   func() (err error) { return nil },
			}},
		},
		TrustedProxies: defaultTrustedProxies,
		PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
		BeforeRequestHandler: &testBeforeRequestHandler{
			onHandleBefore: func(p *Proxy, dctx *DNSContext) (err error) {
				switch dctx.Req.Id {
				case allowedID:
					return nil
				case droppedID:
					// A plain error makes the proxy drop the request.
					return errors.Error("just drop")
				case errorID:
					return &BeforeRequestError{
						Err:      errors.Error("just error"),
						Response: errorResponse,
					}
				default:
					panic(fmt.Sprintf("unexpected request id: %d", dctx.Req.Id))
				}
			},
		},
	})
	ctx := context.Background()
	require.NoError(t, p.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })
	// The short timeout keeps the "dropped" subtest, which must time out, fast.
	client := &dns.Client{
		Net:     string(ProtoTCP),
		Timeout: 200 * time.Millisecond,
	}
	addr := p.Addr(ProtoTCP).String()
	t.Run("allowed", func(t *testing.T) {
		t.Parallel()
		resp, _, err := client.Exchange(allowedRequest, addr)
		require.NoError(t, err)
		assert.Equal(t, allowedResponse, resp)
	})
	t.Run("dropped", func(t *testing.T) {
		t.Parallel()
		// Nothing is sent back, so the client is expected to time out.
		resp, _, err := client.Exchange(droppedRequest, addr)
		wantErr := &net.OpError{}
		require.ErrorAs(t, err, &wantErr)
		assert.True(t, wantErr.Timeout())
		assert.Nil(t, resp)
	})
	t.Run("error", func(t *testing.T) {
		t.Parallel()
		resp, _, err := client.Exchange(errorRequest, addr)
		require.NoError(t, err)
		assert.Equal(t, errorResponse, resp)
	})
}
07070100000053000081A4000000000000000000000001679A649F000002AF000000000000000000000000000000000000002700000000dnsproxy-0.75.0/proxy/bogusnxdomain.gopackage proxy
import (
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// isBogusNXDomain reports whether m is a response to an A or AAAA question
// with at least one answer IP address inside the BogusNXDomain subnets of p.
// It returns false for a nil message, a message without questions, and when p
// has no BogusNXDomain subnets.
func (p *Proxy) isBogusNXDomain(m *dns.Msg) (ok bool) {
	if m == nil || len(p.BogusNXDomain) == 0 || len(m.Question) == 0 {
		return false
	}

	switch m.Question[0].Qtype {
	case dns.TypeA, dns.TypeAAAA:
		// Only address responses are checked.
	default:
		return false
	}

	bogus := netutil.SliceSubnetSet(p.BogusNXDomain)
	for _, rr := range m.Answer {
		if bogus.Contains(proxyutil.IPFromRR(rr)) {
			return true
		}
	}

	return false
}
07070100000054000081A4000000000000000000000001679A649F00000B72000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/proxy/bogusnxdomain_test.gopackage proxy
import (
"context"
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestProxy_IsBogusNXDomain checks that responses whose answers contain an IP
// from the configured BogusNXDomain subnets are turned into NXDOMAIN, while the
// others keep a successful response code.
func TestProxy_IsBogusNXDomain(t *testing.T) {
	prx := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
		// Some prefixes deliberately have host bits set.
		BogusNXDomain: []netip.Prefix{
			netip.MustParsePrefix("4.3.2.1/24"),
			netip.MustParsePrefix("1.2.3.4/8"),
			netip.MustParsePrefix("10.11.12.13/32"),
			netip.MustParsePrefix("102:304:506:708:90a:b0c:d0e:f10/120"),
		},
	})
	testCases := []struct {
		name      string
		ans       []dns.RR
		wantRcode int
	}{{
		name: "bogus_subnet",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("4.3.2.1"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "bogus_big_subnet",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("1.254.254.254"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "bogus_single_ip",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("10.11.12.13"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		// Only the last byte differs from the /120 prefix.
		name: "bogus_6",
		ans: []dns.RR{&dns.AAAA{
			Hdr:  dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10},
			AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 99},
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "non-bogus",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("10.11.12.14"),
		}},
		wantRcode: dns.RcodeSuccess,
	}, {
		// The 15th byte differs, so the address is outside the /120 prefix.
		name: "non-bogus_6",
		ans: []dns.RR{&dns.AAAA{
			Hdr:  dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10},
			AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 15},
		}},
		wantRcode: dns.RcodeSuccess,
	}}
	// Replace the upstream with one whose answers each case sets directly.
	u := testUpstream{}
	prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u}
	ctx := context.Background()
	err := prx.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })
	d := &DNSContext{
		Req: newHostTestMessage("host"),
	}
	// The subtests share u, d, and err, so they must run sequentially and
	// must not call t.Parallel.
	for _, tc := range testCases {
		u.ans = tc.ans
		t.Run(tc.name, func(t *testing.T) {
			err = prx.Resolve(d)
			require.NoError(t, err)
			require.NotNil(t, d.Res)
			assert.Equal(t, tc.wantRcode, d.Res.Rcode)
		})
	}
}
07070100000055000081A4000000000000000000000001679A649F00003F27000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/cache.gopackage proxy
import (
"bytes"
"encoding/binary"
"log/slog"
"math"
"net"
"slices"
"strings"
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
glcache "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/miekg/dns"
)
// defaultCacheSize is the default size of the cache in bytes, 64 KiB.
const defaultCacheSize = 64 * 1024
// cache stores packed DNS responses together with the address of the upstream
// that resolved them.  See [cacheItem.pack] for the format of the entries.
type cache struct {
	// itemsLock protects items.
	itemsLock *sync.RWMutex
	// itemsWithSubnetLock is the lock for the subnet-aware cache.
	//
	// NOTE(review): presumably guards itemsWithSubnet the way itemsLock
	// guards items; confirm in getWithSubnet and the setters.
	itemsWithSubnetLock *sync.RWMutex
	// items is the cache of responses to plain requests.
	items glcache.Cache
	// itemsWithSubnet is the cache of responses to requests that are looked up
	// by ECS subnet.  It's nil unless newCache was called with withECS set.
	itemsWithSubnet glcache.Cache
	// optimistic defines if the cache should return expired items, with a TTL
	// of optimisticTTL, and resolve those again.
	optimistic bool
}
// cacheItem is a single cache entry.  It's a helper type to aggregate the
// item-specific logic.
type cacheItem struct {
	// m contains the cached response.
	m *dns.Msg
	// u contains the address of the upstream which resolved m.  It's empty if
	// no upstream was given when the item was created.
	u string
	// ttl is the time-to-live value for the item, in seconds.  It should be
	// set before calling [cacheItem.pack], which adds it to the current Unix
	// time to get the expiration time.  Items returned by [cache.unpackItem]
	// leave it zero; the remaining TTL is passed to filterMsg instead.
	ttl uint32
}
// respToItem wraps the response m and the upstream u that resolved it into a
// cache item.  It returns nil if m must not be cached, i.e. its cache TTL is
// zero.  u may be nil.  l must not be nil.
func (c *cache) respToItem(m *dns.Msg, u upstream.Upstream, l *slog.Logger) (item *cacheItem) {
	ttl := cacheTTL(m, l)
	if ttl == 0 {
		return nil
	}

	item = &cacheItem{m: m, ttl: ttl}
	if u != nil {
		item.u = u.Address()
	}

	return item
}
// Sizes of the fields of a packed cache item.  The layout produced by
// [cacheItem.pack] is:
//
//	expiration time (expTimeSz bytes, big-endian Unix seconds)
//	message length (packedMsgLenSz bytes, big-endian)
//	packed DNS message
//	upstream address
const (
	// packedMsgLenSz is the exact length of the byte slice capable of storing
	// the length of a packed DNS message.  It's essentially the size of a
	// uint16.
	packedMsgLenSz = 2
	// expTimeSz is the exact length of the byte slice capable of storing the
	// expiration time of the response.  It's essentially the size of a uint32.
	expTimeSz = 4
	// minPackedLen is the minimum length of the packed cacheItem, i.e. the
	// length of its header.
	minPackedLen = expTimeSz + packedMsgLenSz
)
// pack serializes ci into a byte slice: the expiration time (now plus ci.ttl),
// the length of the packed message, the packed message, and the upstream
// address.  ci.ttl should be set before calling it.
func (ci *cacheItem) pack() (packed []byte) {
	// The packing error is ignored here.  unpackItem rejects items with a
	// zero message length.
	//
	// NOTE(review): this assumes Pack returns an empty slice on failure;
	// confirm against the miekg/dns implementation.
	msgData, _ := ci.m.Pack()
	msgLen := len(msgData)

	packed = make([]byte, minPackedLen, minPackedLen+msgLen+len(ci.u))

	expiresAt := uint32(time.Now().Unix()) + ci.ttl
	binary.BigEndian.PutUint32(packed[:expTimeSz], expiresAt)
	binary.BigEndian.PutUint16(packed[expTimeSz:minPackedLen], uint16(msgLen))

	packed = append(packed, msgData...)
	packed = append(packed, ci.u...)

	return packed
}
// optimisticTTL is the TTL, in seconds, given to responses served from expired
// cache items when the cache is optimistic.
const optimisticTTL = 10
// unpackItem converts the data into cacheItem using req as a request message.
// data must have the layout produced by [cacheItem.pack].  The returned item's
// message is a response to req built from the cached one.  Its ttl field is
// left zero.
//
// expired is true if the item exists but is expired.  Expired items are only
// returned if c is optimistic, and then get optimisticTTL.  ci is nil if data
// is too short, expired for a non-optimistic cache, has an empty message, or
// fails to unpack.  req must not be nil.
func (c *cache) unpackItem(data []byte, req *dns.Msg) (ci *cacheItem, expired bool) {
	if len(data) < minPackedLen {
		return nil, false
	}
	// The fields are read from b in the order pack writes them.
	b := bytes.NewBuffer(data)
	expire := int64(binary.BigEndian.Uint32(b.Next(expTimeSz)))
	now := time.Now().Unix()
	var ttl uint32
	if expired = expire <= now; expired {
		if !c.optimistic {
			return nil, expired
		}
		ttl = optimisticTTL
	} else {
		ttl = uint32(expire - now)
	}
	l := int(binary.BigEndian.Uint16(b.Next(packedMsgLenSz)))
	if l == 0 {
		return nil, expired
	}
	m := &dns.Msg{}
	if m.Unpack(b.Next(l)) != nil {
		return nil, expired
	}
	// Build a fresh response to req, so that its ID and question match the
	// current request, and copy the relevant flags from the cached one.
	res := (&dns.Msg{}).SetRcode(req, m.Rcode)
	res.AuthenticatedData = m.AuthenticatedData
	res.RecursionAvailable = m.RecursionAvailable
	var doBit bool
	if o := req.IsEdns0(); o != nil {
		doBit = o.Do()
	}
	// Don't return OPT records from cache, since RFC 6891 forbids caching
	// them.  If the request has the DO bit set, only the OPT RRs are removed;
	// otherwise, all DNSSEC RRs are removed as well.
	filterMsg(res, m, req.AuthenticatedData, doBit, ttl)
	// Whatever remains after the message is the upstream address.
	return &cacheItem{
		m: res,
		u: string(b.Next(b.Len())),
	}, expired
}
// initCache creates p.cache and p.shortFlighter when p.CacheEnabled is set,
// using p.CacheSizeBytes, p.EnableEDNSClientSubnet, and p.CacheOptimistic.
// Otherwise it only logs that the cache is disabled.
func (p *Proxy) initCache() {
	if p.CacheEnabled {
		sz := p.CacheSizeBytes
		p.logger.Info("cache enabled", "size", sz)

		p.cache = newCache(sz, p.EnableEDNSClientSubnet, p.CacheOptimistic)
		p.shortFlighter = newOptimisticResolver(p)

		return
	}

	p.logger.Info("cache disabled")
}
// newCache returns a properly initialized cache.  size is passed to createCache
// for each underlying cache, so a non-positive size means the default one.  The
// subnet cache is only created when withECS is true.  optimistic makes the
// cache return expired items instead of dropping them.
func newCache(size int, withECS, optimistic bool) (c *cache) {
	c = &cache{
		itemsLock:           &sync.RWMutex{},
		itemsWithSubnetLock: &sync.RWMutex{},
		items:               createCache(size),
		optimistic:          optimistic,
	}

	if withECS {
		c.itemsWithSubnet = createCache(size)
	}

	return c
}
// get looks req up in the simple cache.  ci is the cached item, if any, and
// expired reports whether its TTL has passed.  key is the key computed for req;
// it's returned so that callers don't have to compute it again.  key is nil if
// req can't be looked up at all.  Entries that can't be unpacked are removed.
func (c *cache) get(req *dns.Msg) (ci *cacheItem, expired bool, key []byte) {
	c.itemsLock.RLock()
	defer c.itemsLock.RUnlock()

	if !canLookUpInCache(c.items, req) {
		return nil, false, nil
	}

	key = msgToKey(req)

	if data := c.items.Get(key); data != nil {
		ci, expired = c.unpackItem(data, req)
		if ci == nil {
			// The entry is broken or expired while optimistic mode is off.
			c.items.Del(key)
		}
	}

	return ci, expired, key
}
// getWithSubnet returns cached item for the req if it's found by n.  expired
// is true if the item's TTL is expired.  k is the resulting key for req.  It's
// returned to avoid recalculating it afterwards.  When nothing is found, k is
// the key for the zero-length prefix.
//
// Note that a slow longest-prefix-match algorithm is used, so cache searches
// are performed up to mask+1 times: first with the prefix length of n, and then
// with every shorter prefix length down to zero.
func (c *cache) getWithSubnet(req *dns.Msg, n *net.IPNet) (ci *cacheItem, expired bool, k []byte) {
	c.itemsWithSubnetLock.RLock()
	defer c.itemsWithSubnetLock.RUnlock()

	if !canLookUpInCache(c.itemsWithSubnet, req) {
		return nil, false, nil
	}

	ecsIP := n.IP.Mask(n.Mask)
	ipLen := len(ecsIP)
	ones, _ := n.Mask.Size()

	k = msgToKeyWithSubnet(req, ecsIP, ones)
	data := c.itemsWithSubnet.Get(k)

	// To reduce allocations the key is narrowed in place, one bit per
	// iteration.  The key for ones has already been looked up above, so start
	// with ones-1.  When ones is zero, msgToKeyWithSubnet has already omitted
	// the IP address from the key and there is nothing left to try.
	for m := ones - 1; m >= 0 && data == nil; m-- {
		// Set mask identification byte in the key.
		k[keyMaskIndex] = byte(m)

		if m == 0 {
			// The key for the zero prefix doesn't contain the IP address.
			k = slices.Delete(k, keyIPIndex, keyIPIndex+ipLen)
		} else if i := m / 8; i < ipLen {
			// Bit m is the first bit outside the prefix of length m, and it
			// lives in byte m/8.  Keep only the leading m%8 bits of that byte.
			// When m is a multiple of 8 the whole byte is cleared, since
			// shifting a byte left by 8 yields zero.  The bytes after it have
			// already been cleared by the previous iterations or by masking.
			k[keyIPIndex+i] &= ^byte(0) << (8 - m%8)
		}

		data = c.itemsWithSubnet.Get(k)
	}

	if data == nil {
		return nil, false, k
	}

	if ci, expired = c.unpackItem(data, req); ci == nil {
		c.itemsWithSubnet.Del(k)
	}

	return ci, expired, k
}
// canLookUpInCache reports whether a lookup for req can be made in cache: the
// cache must exist, and req must be present and have exactly one question.
func canLookUpInCache(cache glcache.Cache, req *dns.Msg) (ok bool) {
	if cache == nil || req == nil {
		return false
	}

	return len(req.Question) == 1
}
// createCache returns a new LRU cache limited to cacheSize.  A non-positive
// cacheSize means defaultCacheSize.
func createCache(cacheSize int) (glc glcache.Cache) {
	maxSize := uint(defaultCacheSize)
	if cacheSize > 0 {
		maxSize = uint(cacheSize)
	}

	return glcache.New(glcache.Config{
		MaxSize:   maxSize,
		EnableLRU: true,
	})
}
// set stores the response m and the address of its upstream u in the simple
// cache under the key built from m's question.  Nothing is stored if respToItem
// returns nil for m.  l must not be nil.
func (c *cache) set(m *dns.Msg, u upstream.Upstream, l *slog.Logger) {
	item := c.respToItem(m, u, l)
	if item == nil {
		return
	}

	key, packed := msgToKey(m), item.pack()

	c.itemsLock.Lock()
	defer c.itemsLock.Unlock()

	c.items.Set(key, packed)
}

// setWithSubnet stores the response m and the address of its upstream u in the
// subnet cache.  The key is built from m together with subnet's mask and its
// masked IP address.  Nothing is stored if respToItem returns nil for m.  l
// must not be nil.
func (c *cache) setWithSubnet(m *dns.Msg, u upstream.Upstream, subnet *net.IPNet, l *slog.Logger) {
	item := c.respToItem(m, u, l)
	if item == nil {
		return
	}

	ones, _ := subnet.Mask.Size()
	key, packed := msgToKeyWithSubnet(m, subnet.IP.Mask(subnet.Mask), ones), item.pack()

	c.itemsWithSubnetLock.Lock()
	defer c.itemsWithSubnetLock.Unlock()

	c.itemsWithSubnet.Set(key, packed)
}
// clearItems removes every entry from the simple cache.
func (c *cache) clearItems() {
	c.itemsLock.Lock()
	defer c.itemsLock.Unlock()

	c.items.Clear()
}

// clearItemsWithSubnet removes every entry from the subnet cache.  It does
// nothing if the subnet cache doesn't exist, which is the case when ECS is
// disabled.
func (c *cache) clearItemsWithSubnet() {
	if c.itemsWithSubnet != nil {
		c.itemsWithSubnetLock.Lock()
		defer c.itemsWithSubnetLock.Unlock()

		c.itemsWithSubnet.Clear()
	}
}
// cacheTTL returns the number of seconds for which m is valid to be cached.
// For negative answers it follows RFC 2308 on how to cache NXDOMAIN and NODATA
// kinds of responses.  It returns 0 for nil, truncated, or multi-question
// messages, for messages with a zero calculated TTL, and for response codes
// other than NOERROR, NXDOMAIN, and SERVFAIL.  l must not be nil.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-2.1,
// https://datatracker.ietf.org/doc/html/rfc2308#section-2.2.
func cacheTTL(m *dns.Msg, l *slog.Logger) (ttl uint32) {
	switch {
	case m == nil:
		return 0
	case m.Truncated:
		l.Debug("truncated message; not caching")

		return 0
	case len(m.Question) != 1:
		l.Debug("message with wrong number of questions; not caching")

		return 0
	default:
		ttl = calculateTTL(m)
		if ttl == 0 {
			l.Debug("ttl calculated to be 0; not caching")

			return 0
		}
	}

	switch rcode := m.Rcode; rcode {
	case dns.RcodeSuccess:
		if isCacheableSucceded(m) {
			return ttl
		}

		l.Debug("not a cacheable noerror response; not caching")
	case dns.RcodeNameError:
		if isCacheableNegative(m) {
			return ttl
		}

		l.Debug("not a cacheable nxdomain response; not caching")
	case dns.RcodeServerFailure:
		return ttl
	default:
		// slog doesn't interpolate printf verbs in messages, so the rcode goes
		// only into the attribute.
		l.Debug("not a cacheable response code; not caching", "rcode", dns.RcodeToString[rcode])
	}

	return 0
}
// hasIPAns reports whether the answer section of m contains at least one A or
// AAAA resource record.
func hasIPAns(m *dns.Msg) (ok bool) {
	for _, rr := range m.Answer {
		switch rr.Header().Rrtype {
		case dns.TypeA, dns.TypeAAAA:
			return true
		default:
			// Go on.
		}
	}

	return false
}

// isCacheableSucceded reports whether m contains useful data to be cached when
// treated as a successful response.  Responses to questions other than A and
// AAAA always qualify.  A and AAAA responses qualify if they carry an address
// in the answer section or are cacheable according to isCacheableNegative.
// m must have at least one question.
func isCacheableSucceded(m *dns.Msg) (ok bool) {
	switch m.Question[0].Qtype {
	case dns.TypeA, dns.TypeAAAA:
		return hasIPAns(m) || isCacheableNegative(m)
	default:
		return true
	}
}
// isCacheableNegative returns true if the authority section of m contains at
// least a single SOA RR and no NS RRs, so that the response can be declared
// authoritative.  Other record types in that section are ignored.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-5 for the
// information on the responses from the authoritative server that should be
// cached by the forwarder.
func isCacheableNegative(m *dns.Msg) (ok bool) {
	for _, rr := range m.Ns {
		switch rr.Header().Rrtype {
		case dns.TypeSOA:
			ok = true
		case dns.TypeNS:
			// A single NS RR makes the response non-cacheable regardless of
			// any SOA RRs.
			return false
		default:
			// Go on.
		}
	}

	return ok
}
// ServFailMaxCacheTTL is the maximum time-to-live value for caching
// SERVFAIL responses in seconds.  It's within the upper constraint of 5 minutes
// given by RFC 2308.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-7.1.
const ServFailMaxCacheTTL = 30

// calculateTTL returns the number of seconds for which m could be cached:
// the lowest TTL among the non-OPT RRs of all m's sections.  It returns 0 if
// any such RR has a zero TTL or if there are none.  The exception is SERVFAIL
// responses: their TTL is capped at ServFailMaxCacheTTL, and a SERVFAIL
// response without non-OPT RRs gets exactly ServFailMaxCacheTTL.
func calculateTTL(m *dns.Msg) (ttl uint32) {
	// MaxUint32 is a sentinel meaning "no non-OPT RR seen".  Any real TTL found
	// below replaces it.
	ttl = math.MaxUint32

	sections := [][]dns.RR{m.Answer, m.Ns, m.Extra}
	for _, section := range sections {
		for _, rr := range section {
			if ttl = minTTL(rr.Header(), ttl); ttl == 0 {
				return 0
			}
		}
	}

	// The SERVFAIL check must come before the sentinel check, so that SERVFAIL
	// responses without records still get cached.
	if m.Rcode == dns.RcodeServerFailure && ttl > ServFailMaxCacheTTL {
		return ServFailMaxCacheTTL
	}

	if ttl == math.MaxUint32 {
		return 0
	}

	return ttl
}
// minTTL returns the smaller of h's TTL and ttl.  OPT RRs are ignored, so ttl is
// returned unchanged for them.
func minTTL(h *dns.RR_Header, ttl uint32) uint32 {
	if h.Rrtype != dns.TypeOPT && h.Ttl < ttl {
		return h.Ttl
	}

	return ttl
}
// respectTTLOverrides returns ttl adjusted to the cacheMinTTL and cacheMaxTTL
// settings: cacheMinTTL if ttl is below it, cacheMaxTTL if that is non-zero and
// ttl is above it, and ttl itself otherwise.  A zero cacheMaxTTL means no upper
// limit.
func respectTTLOverrides(ttl, cacheMinTTL, cacheMaxTTL uint32) uint32 {
	if ttl < cacheMinTTL {
		return cacheMinTTL
	}

	if cacheMaxTTL != 0 && ttl > cacheMaxTTL {
		return cacheMaxTTL
	}

	return ttl
}
// msgToKey builds the simple-cache key for m's first question: QTYPE and
// QCLASS as big-endian uint16 values followed by QNAME in lower case.  m must
// have at least one question.
func msgToKey(m *dns.Msg) (b []byte) {
	q := m.Question[0]

	const (
		typeOff  = 0
		classOff = typeOff + packedMsgLenSz
		nameOff  = classOff + packedMsgLenSz
	)

	b = make([]byte, nameOff+len(q.Name))
	binary.BigEndian.PutUint16(b[typeOff:], q.Qtype)
	binary.BigEndian.PutUint16(b[classOff:], q.Qclass)
	copy(b[nameOff:], strings.ToLower(q.Name))

	return b
}
const (
	// keyMaskIndex is the index of the byte with mask ones value.  It follows
	// the DO byte and the two uint16 fields for QTYPE and QCLASS.
	keyMaskIndex = 1 + 2*packedMsgLenSz

	// keyIPIndex is the start index of the IP address in the key.
	keyIPIndex = keyMaskIndex + 1
)

// msgToKeyWithSubnet constructs the cache key from DO bit, type, class, subnet
// mask, client's IP address and question's name of m.  ecsIP is expected to be
// masked already.  mask is the number of leading ones of the subnet mask; when
// it's zero, ecsIP isn't put into the key at all.
//
// The intended layout of the key is:
//
//	key[0]      the DO bit, 0 or 1
//	key[1:3]    QTYPE, big-endian
//	key[3:5]    QCLASS, big-endian
//	key[5]      mask, see keyMaskIndex
//	key[6:6+n]  ecsIP of length n, only if mask isn't zero
//	the rest    QNAME in lower case
//
// NOTE(review): QTYPE is actually written at key[0:2], which overwrites the DO
// byte and leaves key[2] zero.  The DO bit therefore doesn't end up in the key,
// and requests with and without DO map to the same key.  See the TODO below.
func msgToKeyWithSubnet(m *dns.Msg, ecsIP net.IP, mask int) (key []byte) {
	q := m.Question[0]
	keyLen := keyIPIndex + len(q.Name)
	masked := mask != 0
	if masked {
		keyLen += len(ecsIP)
	}

	// Initialize the slice.
	key = make([]byte, keyLen)

	// Put DO.
	//
	// NOTE(review): this byte is overwritten by the QTYPE write below.
	opt := m.IsEdns0()
	key[0] = mathutil.BoolToNumber[byte](opt != nil && opt.Do())

	// Put Qtype.
	//
	// TODO(d.kolyshev): We should put Qtype in key[1:].
	binary.BigEndian.PutUint16(key[:], q.Qtype)

	// Put Qclass.
	binary.BigEndian.PutUint16(key[1+packedMsgLenSz:], q.Qclass)

	// Add mask.
	key[keyMaskIndex] = uint8(mask)

	// Put the IP address, if any, and the lowercased name right after it.
	k := keyIPIndex
	if masked {
		k += copy(key[keyIPIndex:], ecsIP)
	}

	copy(key[k:], strings.ToLower(q.Name))

	return key
}
// isDNSSEC reports whether r has one of the DNSSEC RR types: NSEC, NSEC3, DS,
// RRSIG, SIG, or DNSKEY.
func isDNSSEC(r dns.RR) bool {
	t := r.Header().Rrtype

	return t == dns.TypeNSEC ||
		t == dns.TypeNSEC3 ||
		t == dns.TypeDS ||
		t == dns.TypeRRSIG ||
		t == dns.TypeSIG ||
		t == dns.TypeDNSKEY
}
// filterRRSlice returns copies of the RRs from rrs, dropping:
//   - OPT RRs;
//   - DNSSEC RRs whose type isn't except, when do is false.
//
// If ttl is not zero, the TTL of every returned copy is set to ttl.  rrs and the
// RRs in it are not modified.  It returns nil if rrs is empty, and an empty
// non-nil slice if every RR was filtered out.
func filterRRSlice(rrs []dns.RR, do bool, ttl uint32, except uint16) (filtered []dns.RR) {
	if len(rrs) == 0 {
		return nil
	}

	filtered = make([]dns.RR, 0, len(rrs))
	for _, r := range rrs {
		rrType := r.Header().Rrtype
		if rrType == dns.TypeOPT || (!do && rrType != except && isDNSSEC(r)) {
			continue
		}

		// Copy first and only then change the TTL, so that the caller's
		// records stay intact.
		cp := dns.Copy(r)
		if ttl != 0 {
			cp.Header().Ttl = ttl
		}

		filtered = append(filtered, cp)
	}

	return filtered
}
// filterMsg fills the answer, authority, and additional sections of dst with
// filtered copies of m's sections; see filterRRSlice for the rules, with ttl
// applied when non-zero.  It also clears the AD bit of dst unless either ad or
// do is true.  m must have at least one question.
func filterMsg(dst, m *dns.Msg, ad, do bool, ttl uint32) {
	// As RFC 6840 says, validating resolvers should only set the AD bit when a
	// response both meets the conditions listed in RFC 4035, and the request
	// contained either a set DO bit or a set AD bit.
	if !ad && !do {
		dst.AuthenticatedData = false
	}

	// Only the DNSSEC RRs that weren't explicitly asked for are dropped, so the
	// answer section keeps the RRs of the question's own type.
	//
	// See https://datatracker.ietf.org/doc/html/rfc4035#section-3.2.1 and
	// https://github.com/AdguardTeam/dnsproxy/issues/144.
	qtype := m.Question[0].Qtype

	dst.Answer = filterRRSlice(m.Answer, do, ttl, qtype)
	dst.Ns = filterRRSlice(m.Ns, do, ttl, dns.TypeNone)
	dst.Extra = filterRRSlice(m.Extra, do, ttl, dns.TypeNone)
}
07070100000056000081A4000000000000000000000001679A649F00005EB9000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/cache_test.gopackage proxy
import (
"context"
"net"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/miekg/dns"
)
const (
	// testCacheSize is the maximum size of cache for tests.
	testCacheSize = 4096

	// testUpsAddr is the address reported by upstreamWithAddr, which the tests
	// expect to find in the cached items.
	testUpsAddr = "https://upstream.address"
)

// upstreamWithAddr is a fake upstream that only reports testUpsAddr as its
// address.  Exchanging with it or closing it panics.
var upstreamWithAddr = &fakeUpstream{
	onExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) { panic("not implemented") },
	onClose:    func() (err error) { panic("not implemented") },
	onAddress:  func() (addr string) { return testUpsAddr },
}
// TestServeCached checks that a response put directly into the proxy's cache is
// what a UDP client receives for the same question.
func TestServeCached(t *testing.T) {
	const host = "google.com."

	p := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
	})

	ctx := context.Background()
	require.NoError(t, p.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

	// Put the expected reply into the cache.
	want := (&dns.Msg{
		MsgHdr: dns.MsgHdr{Response: true},
		Answer: []dns.RR{newRR(t, host, dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
	}).SetQuestion(host, dns.TypeA)
	want.SetEdns0(defaultUDPBufSize, false)
	p.cache.set(want, upstreamWithAddr, slogutil.NewDiscardLogger())

	client := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}

	req := (&dns.Msg{}).SetQuestion(host, dns.TypeA)
	req.SetEdns0(defaultUDPBufSize, false)

	got, _, err := client.Exchange(req, p.Addr(ProtoUDP).String())
	require.NoErrorf(t, err, "error in the first request: %s", err)

	requireEqualMsgs(t, got, want)
}
// TestCache_expired checks how items with zero and non-zero TTLs are returned
// by the cache with and without optimistic mode.
func TestCache_expired(t *testing.T) {
	const host = "google.com."

	ans := &dns.A{
		Hdr: dns.RR_Header{
			Name:   host,
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
		},
		A: net.IP{8, 8, 8, 8},
	}

	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{Response: true},
		Answer: []dns.RR{ans},
	}).SetQuestion(host, dns.TypeA)

	testCases := []struct {
		name       string
		ttl        uint32
		wantTTL    uint32
		optimistic bool
	}{{
		name:       "realistic_hit",
		ttl:        defaultTestTTL,
		wantTTL:    defaultTestTTL,
		optimistic: false,
	}, {
		name:       "realistic_miss",
		ttl:        0,
		wantTTL:    0,
		optimistic: false,
	}, {
		name:       "optimistic_hit",
		ttl:        defaultTestTTL,
		wantTTL:    defaultTestTTL,
		optimistic: true,
	}, {
		name:       "optimistic_expired",
		ttl:        0,
		wantTTL:    optimisticTTL,
		optimistic: true,
	}}

	testCache := newCache(testCacheSize, false, false)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// The subtests run one after another, so changing the shared
			// answer and cache here is safe.
			ans.Hdr.Ttl = tc.ttl

			testCache.optimistic = tc.optimistic
			t.Cleanup(func() { testCache.optimistic = false })

			packed := (&cacheItem{
				m:   reply,
				u:   testUpsAddr,
				ttl: tc.ttl,
			}).pack()
			testCache.items.Set(msgToKey(reply), packed)
			t.Cleanup(testCache.items.Clear)

			req := (&dns.Msg{}).SetQuestion(host, dns.TypeA)

			ci, expired, key := testCache.get(req)
			assert.Equal(t, msgToKey(req), key)
			assert.Equal(t, tc.ttl == 0, expired)

			if tc.wantTTL == 0 {
				require.Nil(t, ci)

				return
			}

			require.NotNil(t, ci)
			assert.Equal(t, tc.wantTTL, ci.m.Answer[0].Header().Ttl)
			assert.Equal(t, testUpsAddr, ci.u)
		})
	}
}
// TestCacheDO checks that a response cached with the DO bit set is found both
// for a request without EDNS and for a request with the DO bit set.
func TestCacheDO(t *testing.T) {
	const host = "google.com."

	testCache := newCache(testCacheSize, false, false)

	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{Response: true},
		Answer: []dns.RR{newRR(t, host, dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
	}).SetQuestion(host, dns.TypeA)
	reply.SetEdns0(4096, true)

	testCache.set(reply, upstreamWithAddr, slogutil.NewDiscardLogger())

	plainReq := (&dns.Msg{}).SetQuestion(host, dns.TypeA)

	t.Run("without_do", func(t *testing.T) {
		ci, expired, key := testCache.get(plainReq)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(plainReq), key)
		assert.NotNil(t, ci)
	})

	t.Run("with_do", func(t *testing.T) {
		doReq := plainReq.Copy()
		doReq.SetEdns0(4096, true)

		ci, expired, key := testCache.get(doReq)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(doReq), key)
		require.NotNil(t, ci)
		assert.Equal(t, testUpsAddr, ci.u)
	})
}
// TestCacheCNAME checks that a response to an A question is cached only when
// its answer section contains an A record in addition to the CNAME record.
func TestCacheCNAME(t *testing.T) {
	const host = "google.com."

	l := slogutil.NewDiscardLogger()
	testCache := newCache(testCacheSize, false, false)
	req := (&dns.Msg{}).SetQuestion(host, dns.TypeA)

	// Try to cache a reply with only a CNAME record.
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{Response: true},
		Answer: []dns.RR{newRR(t, host, dns.TypeCNAME, 3600, "test.google.com.")},
	}).SetQuestion(host, dns.TypeA)
	testCache.set(reply, upstreamWithAddr, l)

	t.Run("no_cnames", func(t *testing.T) {
		ci, expired, _ := testCache.get(req)
		assert.Nil(t, ci)
		assert.False(t, expired)
	})

	// Add an address to the reply, which makes it cacheable.
	reply.Answer = append(reply.Answer, newRR(t, host, dns.TypeA, 3600, net.IP{8, 8, 8, 8}))
	testCache.set(reply, upstreamWithAddr, l)

	t.Run("cnames_exist", func(t *testing.T) {
		ci, expired, key := testCache.get(req)
		assert.False(t, expired)
		assert.Equal(t, key, msgToKey(req))
		require.NotNil(t, ci)
		assert.Equal(t, testUpsAddr, ci.u)
	})
}
// TestCache_uncacheable checks that a BADALG reply to a request isn't stored in
// the cache.  Note that the reply has no records at all, so its calculated TTL
// is zero as well.
func TestCache_uncacheable(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	// Create a DNS request.
	request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)

	// Try to fill the cache with a response whose code is neither NOERROR,
	// NXDOMAIN, nor SERVFAIL.
	reply := (&dns.Msg{}).SetRcode(request, dns.RcodeBadAlg)
	testCache.set(reply, upstreamWithAddr, slogutil.NewDiscardLogger())

	r, expired, _ := testCache.get(request)
	assert.Nil(t, r)
	assert.False(t, expired)
}
// TestCache_concurrent checks that the cache can be filled and read
// concurrently.  Every host is handled in its own parallel subtest rather than
// a bare goroutine, because setAndGetCache uses require, which calls FailNow.
// FailNow must only be called from the goroutine running the test.
func TestCache_concurrent(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	hosts := map[string]string{
		dns.Fqdn("yandex.com"):     "213.180.204.62",
		dns.Fqdn("google.com"):     "8.8.8.8",
		dns.Fqdn("www.google.com"): "8.8.4.4",
		dns.Fqdn("youtube.com"):    "173.194.221.198",
		dns.Fqdn("car.ru"):         "37.220.161.35",
		dns.Fqdn("cat.ru"):         "192.56.231.67",
	}

	for host, ip := range hosts {
		t.Run(host, func(t *testing.T) {
			t.Parallel()

			// setAndGetCache signals completion through a wait group.  A
			// per-subtest one keeps the counter balanced even when some
			// subtests are filtered out with -run.
			wg := &sync.WaitGroup{}
			wg.Add(1)

			setAndGetCache(t, testCache, wg, host, ip)
		})
	}
}
const (
	// cacheTick is the period of cache checks in the tests that wait for
	// cached items to expire.
	cacheTick = 100 * time.Millisecond

	// cacheTimeout is the overall timeout of such cache checks.
	cacheTimeout = 20 * cacheTick
)

// TestCacheExpiration checks that responses with a TTL of 1 second are served
// from the proxy's cache at first and disappear from it eventually.
func TestCacheExpiration(t *testing.T) {
	t.Parallel()

	p := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
	})

	ctx := context.Background()
	require.NoError(t, p.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

	l := slogutil.NewDiscardLogger()

	answers := []dns.RR{
		newRR(t, "youtube.com.", dns.TypeA, 1, net.IP{173, 194, 221, 198}),
		newRR(t, "google.com.", dns.TypeA, 1, net.IP{8, 8, 8, 8}),
		newRR(t, "yandex.com.", dns.TypeA, 1, net.IP{213, 180, 204, 62}),
	}

	replies := make([]*dns.Msg, 0, len(answers))
	for _, ans := range answers {
		rep := (&dns.Msg{
			MsgHdr: dns.MsgHdr{Response: true},
			Answer: []dns.RR{dns.Copy(ans)},
		}).SetQuestion(ans.Header().Name, dns.TypeA)

		p.cache.set(rep, upstreamWithAddr, l)
		replies = append(replies, rep)
	}

	for _, rep := range replies {
		ci, expired, key := p.cache.get(rep)
		require.NotNil(t, ci)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(ci.m), key)
		requireEqualMsgs(t, ci.m, rep)
	}

	allGone := func() (ok bool) {
		for _, rep := range replies {
			if ci, _, _ := p.cache.get(rep); ci != nil {
				return false
			}
		}

		return true
	}

	assert.Eventually(t, allGone, cacheTimeout, cacheTick)
}
// TestCacheExpirationWithTTLOverride checks that resolved responses end up in
// the cache with their TTLs adjusted to CacheMinTTL and CacheMaxTTL.
func TestCacheExpirationWithTTLOverride(t *testing.T) {
	u := testUpstream{}
	dnsProxy := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{&u},
		},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
		CacheMinTTL:            20,
		CacheMaxTTL:            40,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	d := &DNSContext{}

	t.Run("replace_min", func(t *testing.T) {
		d.Req = newHostTestMessage("host")
		d.Addr = netip.AddrPort{}

		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host.",
				Ttl:    10,
			},
			A: net.IP{4, 3, 2, 1},
		}}

		require.NoError(t, dnsProxy.Resolve(d))

		ci, expired, key := dnsProxy.cache.get(d.Req)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(d.Req), key)
		require.NotNil(t, ci)
		assert.Equal(t, dnsProxy.CacheMinTTL, ci.m.Answer[0].Header().Ttl)
	})

	t.Run("replace_max", func(t *testing.T) {
		d.Req = newHostTestMessage("host2")
		d.Addr = netip.AddrPort{}

		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host2.",
				Ttl:    60,
			},
			A: net.IP{4, 3, 2, 1},
		}}

		// A failed resolution makes the cache checks below meaningless, so
		// stop right away, as in replace_min.
		require.NoError(t, dnsProxy.Resolve(d))

		ci, expired, key := dnsProxy.cache.get(d.Req)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(d.Req), key)
		require.NotNil(t, ci)
		assert.Equal(t, dnsProxy.CacheMaxTTL, ci.m.Answer[0].Header().Ttl)
	})
}
type (
	// testEntry is a response to put into the cache before the lookups: the
	// answer a to the question with name q and type t.
	testEntry struct {
		q string
		a []dns.RR
		t uint16
	}

	// testCase is a single cache lookup for the question with name q and type
	// t.  ok is applied to whether an item was found, and a, if not nil, is the
	// expected answer.
	testCase struct {
		ok require.BoolAssertionFunc
		q  string
		a  []dns.RR
		t  uint16
	}

	// testCases is a set of responses to fill the cache with and lookups to
	// perform against it.
	testCases struct {
		cache []testEntry
		cases []testCase
	}
)
// TestCache checks lookups of cached responses, including case-insensitive
// matching of question names and the rejection of zero-TTL responses.
func TestCache(t *testing.T) {
	t.Run("simple", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "google.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.True,
				q:  "google.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})

	t.Run("mixed_case", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "gOOgle.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.True,
				q:  "gOOgle.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.True,
				q:  "google.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.True,
				q:  "GOOGLE.COM.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "gOOgle.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "GOOGLE.COM.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})

	// The lookups mirror mixed_case: a zero-TTL response must not be found
	// regardless of the case of the question name.
	t.Run("zero_ttl", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "gOOgle.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 0, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.False,
				q:  "gOOgle.com.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "GOOGLE.COM.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "gOOgle.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "GOOGLE.COM.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})
}
// run fills a new cache with the responses from tests.cache and then performs
// every lookup from tests.cases against it.
func (tests testCases) run(t *testing.T) {
	l := slogutil.NewDiscardLogger()
	testCache := newCache(testCacheSize, false, false)

	for _, res := range tests.cache {
		reply := (&dns.Msg{
			MsgHdr: dns.MsgHdr{
				Response: true,
			},
			Answer: res.a,
		}).SetQuestion(res.q, res.t)

		testCache.set(reply, upstreamWithAddr, l)
	}

	for _, tc := range tests.cases {
		request := (&dns.Msg{}).SetQuestion(tc.q, tc.t)

		ci, expired, _ := testCache.get(request)
		assert.False(t, expired)
		tc.ok(t, ci != nil)

		// Cases without an expected answer only check whether an item is
		// present, and a missing item leaves nothing to compare.  Either way,
		// go on with the next case instead of silently skipping the rest.
		if tc.a == nil || ci == nil {
			continue
		}

		reply := (&dns.Msg{
			MsgHdr: dns.MsgHdr{
				Response: true,
			},
			Answer: tc.a,
		}).SetQuestion(tc.q, tc.t)

		testCache.set(reply, upstreamWithAddr, l)
		requireEqualMsgs(t, ci.m, reply)
	}
}
// requireEqualMsgs asserts the messages are equal except their ID, Rdlength, and
// the case of questions.  It requires both messages to have the same number of
// answers.
//
// NOTE: Both messages are modified.  expected is only copied shallowly, so the
// Rdlength values of its answers are overwritten with those of actual and its
// question names are lowercased.  The question names of actual are lowercased
// too, and the addresses of its A records are converted to the 4-byte form.
func requireEqualMsgs(t *testing.T, expected, actual *dns.Msg) {
	t.Helper()

	want := *expected
	want.Id = actual.Id

	require.Equal(t, len(want.Answer), len(actual.Answer))

	for i, rr := range actual.Answer {
		want.Answer[i].Header().Rdlength = rr.Header().Rdlength

		a, ok := rr.(*dns.A)
		if !ok {
			continue
		}

		if ip4 := a.A.To4(); ip4 != nil {
			a.A = ip4
		}
	}

	for _, questions := range [][]dns.Question{want.Question, actual.Question} {
		for i := range questions {
			questions[i].Name = strings.ToLower(questions[i].Name)
		}
	}

	assert.Equal(t, &want, actual)
}
// setAndGetCache stores a response with an A record for host pointing at ip
// and a TTL of 1 second in c.  It then checks twice that the response can be
// read back, and finally asserts that it eventually disappears.  It calls
// g.Done when it returns.
func setAndGetCache(t *testing.T, c *cache, g *sync.WaitGroup, host, ip string) {
	defer g.Done()

	msg := (&dns.Msg{
		MsgHdr: dns.MsgHdr{Response: true},
		Answer: []dns.RR{newRR(t, host, dns.TypeA, 1, net.ParseIP(ip))},
	}).SetQuestion(host, dns.TypeA)

	c.set(msg, upstreamWithAddr, slogutil.NewDiscardLogger())

	for i := 0; i < 2; i++ {
		ci, expired, key := c.get(msg)
		require.NotNilf(t, ci, "no cache found for %s", host)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(msg), key)
		requireEqualMsgs(t, ci.m, msg)
	}

	isGone := func() bool {
		ci, _, _ := c.get(msg)

		return ci == nil
	}

	assert.Eventuallyf(t, isGone, cacheTimeout, cacheTick, "cache for %s should already be removed", host)
}
// TestCache_getWithSubnet checks lookups in the subnet cache: misses, matches by
// a shorter prefix, and the fallback to the entry stored without a subnet.
func TestCache_getWithSubnet(t *testing.T) {
	const testFQDN = "example.com."

	var (
		ip1234 = net.IP{1, 2, 3, 4}
		ip2234 = net.IP{2, 2, 3, 4}
		ip3234 = net.IP{3, 2, 3, 4}
	)

	req := (&dns.Msg{}).SetQuestion(testFQDN, dns.TypeA)

	mask16 := net.CIDRMask(16, netutil.IPv4BitLen)
	mask24 := net.CIDRMask(24, netutil.IPv4BitLen)

	l := slogutil.NewDiscardLogger()
	c := newCache(testCacheSize, true, false)

	// newResp returns a reply to req with a single A record pointing at ip.
	newResp := func(ip net.IP) (resp *dns.Msg) {
		return (&dns.Msg{
			Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 1, ip)},
		}).SetReply(req)
	}

	// requireAnswer requires ci to hold an A record pointing at want.
	requireAnswer := func(t *testing.T, ci *cacheItem, want net.IP) {
		t.Helper()

		require.NotNil(t, ci)
		require.NotNil(t, ci.m)
		require.NotEmpty(t, ci.m.Answer)

		a := testutil.RequireTypeAssert[*dns.A](t, ci.m.Answer[0])
		assert.True(t, a.A.Equal(want))
	}

	t.Run("empty", func(t *testing.T) {
		ci, expired, _ := c.getWithSubnet(req, &net.IPNet{IP: ip1234, Mask: mask24})
		assert.Nil(t, ci)
		assert.False(t, expired)
	})

	// Add a response with subnet.
	c.setWithSubnet(newResp(net.IP{1, 1, 1, 1}), upstreamWithAddr, &net.IPNet{IP: ip1234, Mask: mask16}, l)

	t.Run("different_ip", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip2234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip2234, 0), key)
		assert.Nil(t, ci)
	})

	// Add a response entry with subnet #2 and one without subnet.
	c.setWithSubnet(newResp(net.IP{2, 2, 2, 2}), upstreamWithAddr, &net.IPNet{IP: ip2234, Mask: mask16}, l)
	c.setWithSubnet(newResp(net.IP{3, 3, 3, 3}), upstreamWithAddr, &net.IPNet{IP: nil, Mask: nil}, l)

	t.Run("with_subnet_1", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip1234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip1234.Mask(mask16), 16), key)
		requireAnswer(t, ci, net.IP{1, 1, 1, 1})
	})

	t.Run("with_subnet_2", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip2234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip2234.Mask(mask16), 16), key)
		requireAnswer(t, ci, net.IP{2, 2, 2, 2})
	})

	t.Run("with_subnet_3", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip3234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip1234, 0), key)
		requireAnswer(t, ci, net.IP{3, 3, 3, 3})
	})
}
// TestCache_getWithSubnet_mask checks that a /24 lookup finds the entry cached
// for a /20 network containing the address, and misses for an address outside
// of that network.
func TestCache_getWithSubnet_mask(t *testing.T) {
	const (
		testFQDN = "example.com."

		// cidrMaskOnes is the prefix length of the cached network.
		cidrMaskOnes = 20
	)

	var (
		testIP    = net.IP{176, 112, 191, 0}
		noMatchIP = net.IP{177, 112, 191, 0}
		ansIP     = net.IP{4, 4, 4, 4}

		// cachedIP/cidrMask network contains the testIP.
		cachedIP = net.IP{176, 112, 176, 0}
		cidrMask = net.CIDRMask(cidrMaskOnes, netutil.IPv4BitLen)
	)

	c := newCache(testCacheSize, true, true)

	req := (&dns.Msg{}).SetQuestion(testFQDN, dns.TypeA)
	resp := (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 300, ansIP)},
	}).SetReply(req)

	cachedNet := &net.IPNet{IP: cachedIP, Mask: cidrMask}
	c.setWithSubnet(resp, upstreamWithAddr, cachedNet, slogutil.NewDiscardLogger())

	lookupMask := net.CIDRMask(24, netutil.IPv4BitLen)

	t.Run("mask_matched", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: testIP, Mask: lookupMask})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, testIP.Mask(cidrMask), cidrMaskOnes), key)

		require.NotNil(t, ci)
		require.NotNil(t, ci.m)
		require.NotEmpty(t, ci.m.Answer)

		a := testutil.RequireTypeAssert[*dns.A](t, ci.m.Answer[0])
		assert.True(t, a.A.Equal(ansIP))
	})

	t.Run("no_mask_matched", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: noMatchIP, Mask: lookupMask})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, noMatchIP, 0), key)
		assert.Nil(t, ci)
	})
}
// TestCache_IsCacheable_negative checks the TTL computed by cacheTTL for the
// negative and referral responses described in RFC 2308, and for a SERVFAIL
// response.
func TestCache_IsCacheable_negative(t *testing.T) {
	const someTTL = 3600

	// rrHdr returns the header of a class IN resource record with the given
	// owner name and type, and someTTL.
	rrHdr := func(name string, rrType uint16) (hdr dns.RR_Header) {
		return dns.RR_Header{
			Name:   name,
			Rrtype: rrType,
			Class:  dns.ClassINET,
			Ttl:    someTTL,
		}
	}

	msgHdr := func(rcode int) (hdr dns.MsgHdr) { return dns.MsgHdr{Id: dns.Id(), Rcode: rcode} }

	aQuestions := func(name string) []dns.Question {
		return []dns.Question{{
			Name:   name,
			Qtype:  dns.TypeA,
			Qclass: dns.ClassINET,
		}}
	}

	cnameAns := func(name, cname string) (rr dns.RR) {
		return &dns.CNAME{Hdr: rrHdr(name, dns.TypeCNAME), Target: cname}
	}

	soaAns := func(name, ns, mbox string) (rr dns.RR) {
		return &dns.SOA{Hdr: rrHdr(name, dns.TypeSOA), Ns: ns, Mbox: mbox}
	}

	nsAns := func(name, ns string) (rr dns.RR) {
		return &dns.NS{Hdr: rrHdr(name, dns.TypeNS), Ns: ns}
	}

	aAns := func(name string, a net.IP) (rr dns.RR) {
		return &dns.A{Hdr: rrHdr(name, dns.TypeA), A: a}
	}

	const (
		hostname        = "AN.EXAMPLE."
		anotherHostname = "ANOTHER.EXAMPLE."
		cname           = "TRIPPLE.XX."
		mbox            = "HOSTMASTER.NS1.XX."
		ns1, ns2        = "NS1.XX.", "NS2.XX."
		xx              = "XX."
	)

	// See https://datatracker.ietf.org/doc/html/rfc2308.
	testCases := []struct {
		req     *dns.Msg
		name    string
		wantTTL uint32
	}{{
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
			Ns: []dns.RR{
				soaAns(xx, ns1, mbox),
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nxdomain_response_type_1",
		wantTTL: 0,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
			// Use the shared zone constant like the other cases do.
			Ns: []dns.RR{soaAns(xx, ns1, mbox)},
		},
		name:    "rfc2308_nxdomain_response_type_2",
		wantTTL: someTTL,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
		},
		name:    "rfc2308_nxdomain_response_type_3",
		wantTTL: 0,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
			Ns: []dns.RR{
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nxdomain_response_type_4",
		wantTTL: 0,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
			Ns: []dns.RR{
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nxdomain_referral_response",
		wantTTL: 0,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(anotherHostname),
			Ns: []dns.RR{
				soaAns(xx, ns1, mbox),
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nodata_response_type_1",
		wantTTL: 0,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(anotherHostname),
			Ns:       []dns.RR{soaAns(xx, ns1, mbox)},
		},
		name:    "rfc2308_nodata_response_type_2",
		wantTTL: someTTL,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(anotherHostname),
		},
		name:    "rfc2308_nodata_response_type_3",
		wantTTL: 0,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(anotherHostname),
			Ns: []dns.RR{
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nodata_referral_response",
		wantTTL: 0,
	}, {
		req: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeServerFailure),
			Question: aQuestions(anotherHostname),
		},
		name:    "servfail_response",
		wantTTL: ServFailMaxCacheTTL,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.wantTTL, cacheTTL(tc.req, slogutil.NewDiscardLogger()))
		})
	}
}
07070100000057000081A4000000000000000000000001679A649F000001F3000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/clock.gopackage proxy
import "time"
// clock is the interface for a provider of the current time. It's used to
// simplify testing.
//
// TODO(e.burkov): Move to golibs.
type clock interface {
	// Now returns the current local time.
	Now() (now time.Time)
}

// type check
var _ clock = realClock{}

// realClock is the [clock] which actually uses the [time] package.
type realClock struct{}

// Now implements the [clock] interface for realClock.
func (realClock) Now() (now time.Time) { return time.Now() }
07070100000058000081A4000000000000000000000001679A649F000034B5000000000000000000000000000000000000002000000000dnsproxy-0.75.0/proxy/config.gopackage proxy
import (
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/netip"
"net/url"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/ameshkov/dnscrypt/v2"
)
// LogPrefix is the prefix used for logging.
const LogPrefix = "dnsproxy"

type (
	// RequestHandler is an optional custom handler for DNS requests. When
	// set, it's used instead of [Proxy.Resolve]. The error it returns doesn't
	// affect the request processing. Unless it calls [Proxy.Resolve], the
	// handler itself is responsible for calling [ResponseHandler].
	//
	// TODO(e.burkov): Use the same interface-based approach as
	// [BeforeRequestHandler].
	RequestHandler func(p *Proxy, dctx *DNSContext) (err error)

	// ResponseHandler is an optional custom handler that is called once a DNS
	// query has been processed. When it's called from [Proxy.Resolve], dctx
	// contains the response message if either the upstream or the cache
	// succeeded, and err is non-nil only if the upstream failed to respond.
	//
	// TODO(e.burkov): Use the same interface-based approach as
	// [BeforeRequestHandler].
	ResponseHandler func(dctx *DNSContext, err error)
)
// Config contains all the fields necessary for proxy configuration.
//
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
type Config struct {
	// Logger is used as the base logger for the proxy service. If nil,
	// [slog.Default] with [LogPrefix] is used.
	Logger *slog.Logger

	// TrustedProxies is the trusted list of CIDR networks to detect proxy
	// servers addresses from where the DoH requests should be handled. The
	// value of nil makes Proxy not trust any address.
	TrustedProxies netutil.SubnetSet

	// PrivateSubnets is the set of private networks. A client with an address
	// within this set is able to resolve PTR requests for addresses within
	// this set.
	PrivateSubnets netutil.SubnetSet

	// MessageConstructor is used to build DNS messages. If nil, the default
	// constructor will be used.
	MessageConstructor MessageConstructor

	// BeforeRequestHandler is an optional custom handler called before each DNS
	// request is started processing, see [BeforeRequestHandler]. The default
	// no-op implementation is used, if it's nil.
	BeforeRequestHandler BeforeRequestHandler

	// RequestHandler is an optional custom handler for DNS requests. It's used
	// instead of [Proxy.Resolve] if set. See [RequestHandler].
	RequestHandler RequestHandler

	// ResponseHandler is an optional custom handler called when DNS query has
	// been processed. See [ResponseHandler].
	ResponseHandler ResponseHandler

	// UpstreamConfig is a general set of DNS servers to forward requests to.
	// It must pass validation; an invalid config makes the proxy
	// configuration invalid.
	UpstreamConfig *UpstreamConfig

	// PrivateRDNSUpstreamConfig is the set of upstream DNS servers for
	// resolving private IP addresses. All the requests considered private will
	// be resolved via these upstream servers. Such queries will finish with
	// [upstream.ErrNoUpstreams] if it's empty.
	PrivateRDNSUpstreamConfig *UpstreamConfig

	// Fallbacks is a list of fallback resolvers. Those will be used if the
	// general set fails responding. nil disables fallbacks, while a non-nil
	// config without upstreams is rejected by validation.
	Fallbacks *UpstreamConfig

	// Userinfo is the sole permitted userinfo for the DoH basic authentication.
	// If Userinfo is set, all DoH queries are required to have this basic
	// authentication information.
	Userinfo *url.Userinfo

	// TLSConfig is the TLS configuration. Required for DNS-over-TLS,
	// DNS-over-HTTP, and DNS-over-QUIC servers.
	TLSConfig *tls.Config

	// DNSCryptResolverCert is the DNSCrypt resolver certificate. Required for
	// DNSCrypt server.
	DNSCryptResolverCert *dnscrypt.Cert

	// DNSCryptProviderName is the DNSCrypt provider name. Required for
	// DNSCrypt server.
	DNSCryptProviderName string

	// HTTPSServerName sets the Server header of the HTTPS server responses, if
	// not empty.
	HTTPSServerName string

	// UpstreamMode determines the logic through which upstreams will be used.
	// If not specified the [proxy.UpstreamModeLoadBalance] is used. Values
	// other than the empty one, [UpstreamModeFastestAddr],
	// [UpstreamModeLoadBalance], and [UpstreamModeParallel] are rejected by
	// validation.
	UpstreamMode UpstreamMode

	// UDPListenAddr is the set of UDP addresses to listen for plain
	// DNS-over-UDP requests.
	UDPListenAddr []*net.UDPAddr

	// TCPListenAddr is the set of TCP addresses to listen for plain
	// DNS-over-TCP requests.
	TCPListenAddr []*net.TCPAddr

	// HTTPSListenAddr is the set of TCP addresses to listen for DNS-over-HTTPS
	// requests.
	HTTPSListenAddr []*net.TCPAddr

	// TLSListenAddr is the set of TCP addresses to listen for DNS-over-TLS
	// requests.
	TLSListenAddr []*net.TCPAddr

	// QUICListenAddr is the set of UDP addresses to listen for DNS-over-QUIC
	// requests.
	QUICListenAddr []*net.UDPAddr

	// DNSCryptUDPListenAddr is the set of UDP addresses to listen for DNSCrypt
	// requests.
	DNSCryptUDPListenAddr []*net.UDPAddr

	// DNSCryptTCPListenAddr is the set of TCP addresses to listen for DNSCrypt
	// requests.
	DNSCryptTCPListenAddr []*net.TCPAddr

	// BogusNXDomain is the set of networks used to transform responses into
	// NXDOMAIN ones if they contain at least a single IP address within these
	// networks. It's similar to dnsmasq's "bogus-nxdomain".
	BogusNXDomain []netip.Prefix

	// DNS64Prefs is the set of NAT64 prefixes used for DNS64 handling. nil
	// value disables the feature. An empty value will be interpreted as the
	// default Well-Known Prefix. Each prefix must be an IPv6 prefix no longer
	// than 96 bits, and only the first one is used to synthesize addresses.
	DNS64Prefs []netip.Prefix

	// RatelimitWhitelist is a list of IP addresses excluded from rate limiting.
	RatelimitWhitelist []netip.Addr

	// EDNSAddr is the ECS IP used in request.
	EDNSAddr net.IP

	// TODO(s.chzhen): Extract ratelimit settings to a separate structure.

	// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
	// rate limiting requests. When Ratelimit is non-zero, it must be between 0
	// and [netutil.IPv4BitLen], inclusive.
	RatelimitSubnetLenIPv4 int

	// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
	// rate limiting requests. When Ratelimit is non-zero, it must be between 0
	// and [netutil.IPv6BitLen], inclusive.
	RatelimitSubnetLenIPv6 int

	// Ratelimit is a maximum number of requests per second from a given IP (0
	// to disable).
	Ratelimit int

	// CacheSizeBytes is the maximum cache size in bytes.
	CacheSizeBytes int

	// CacheMinTTL is the minimum TTL for cached DNS responses in seconds.
	CacheMinTTL uint32

	// CacheMaxTTL is the maximum TTL for cached DNS responses in seconds.
	CacheMaxTTL uint32

	// MaxGoroutines is the maximum number of goroutines processing DNS
	// requests. Important for mobile users.
	//
	// TODO(a.garipov): Rename this to something like “MaxDNSRequestGoroutines”
	// in a later major version, as it doesn't actually limit all goroutines.
	MaxGoroutines uint

	// UDPBufferSize is the size of the read buffer on the underlying socket.
	// Larger read buffers can handle larger bursts of requests before packets
	// get dropped.
	UDPBufferSize int

	// FastestPingTimeout is the timeout for waiting the first successful
	// dialing when the UpstreamMode is set to [UpstreamModeFastestAddr].
	// Non-positive value will be replaced with the default one.
	FastestPingTimeout time.Duration

	// RefuseAny makes proxy refuse the requests of type ANY.
	RefuseAny bool

	// HTTP3 enables HTTP/3 support for HTTPS server.
	HTTP3 bool

	// EnableEDNSClientSubnet enables the EDNS Client Subnet option. DNS
	// requests to the upstream server will contain an OPT record with Client
	// Subnet option. If the original request already has this option set, we
	// pass it through as is. Otherwise, we set it ourselves using the client
	// IP with subnet /24 (for IPv4) and /56 (for IPv6).
	//
	// If the upstream server supports ECS, it sets subnet number in the
	// response. This subnet number along with the client IP and other data is
	// used as a cache key. Next time, if a client from the same subnet
	// requests this host name, we get the response from cache. If another
	// client from a different subnet requests this host name, we pass their
	// request to the upstream server.
	//
	// If the upstream server doesn't support ECS (there's no subnet number in
	// response), this response will be cached for all clients.
	//
	// If client IP is private (i.e. not public), we don't add EDNS record into
	// a request. And so there will be no EDNS record in response either. We
	// store these responses in general cache (without subnet) so they will
	// never be used for clients with public IP addresses.
	EnableEDNSClientSubnet bool

	// CacheEnabled defines if the response cache should be used.
	CacheEnabled bool

	// CacheOptimistic defines if the optimistic cache mechanism should be used.
	CacheOptimistic bool

	// UseDNS64 enables DNS64 handling. If true, proxy will translate IPv4
	// answers into IPv6 answers using first of DNS64Prefs. Note also that PTR
	// requests for addresses within the specified networks are considered
	// private and will be forwarded as PrivateRDNSUpstreamConfig specifies.
	// Those will be responded with NXDOMAIN if UsePrivateRDNS is false.
	UseDNS64 bool

	// UsePrivateRDNS defines if the PTR requests for private IP addresses
	// should be resolved via PrivateRDNSUpstreamConfig. Note that it requires
	// a valid PrivateRDNSUpstreamConfig with at least a single general upstream
	// server.
	UsePrivateRDNS bool

	// PreferIPv6 tells the proxy to prefer IPv6 addresses when bootstrapping
	// upstreams that use hostnames.
	PreferIPv6 bool
}
// validateConfig checks the supplied configuration and returns an error
// describing the first problem found. On success, it logs the notable
// configuration settings.
//
// TODO(s.chzhen): Use [validate.Interface] from golibs.
func (p *Proxy) validateConfig() (err error) {
	if err = p.UpstreamConfig.validate(); err != nil {
		return fmt.Errorf("validating general upstreams: %w", err)
	}

	// Private RDNS upstream problems are only fatal when the feature is used
	// or when the config has no upstreams at all.
	err = ValidatePrivateConfig(p.PrivateRDNSUpstreamConfig, p.privateNets)
	if err != nil && (p.UsePrivateRDNS || errors.Is(err, upstream.ErrNoUpstreams)) {
		return fmt.Errorf("validating private RDNS upstreams: %w", err)
	}

	// A nil [Proxy.Fallbacks] means no fallbacks are used, but an empty one is
	// a mistake.
	if err = p.Fallbacks.validate(); errors.Is(err, upstream.ErrNoUpstreams) {
		return fmt.Errorf("validating fallbacks: %w", err)
	}

	if err = p.validateRatelimit(); err != nil {
		return fmt.Errorf("validating ratelimit: %w", err)
	}

	switch p.UpstreamMode {
	case "", UpstreamModeFastestAddr, UpstreamModeLoadBalance, UpstreamModeParallel:
		// Go on.
	default:
		return fmt.Errorf("bad upstream mode: %q", p.UpstreamMode)
	}

	p.logConfigInfo()

	return nil
}
// validateRatelimit returns an error if rate limiting is enabled and either of
// the subnet lengths is out of range for its address family.
func (p *Proxy) validateRatelimit() (err error) {
	if p.Ratelimit == 0 {
		// Rate limiting is disabled, nothing to check.
		return nil
	}

	if err = checkInclusion(p.RatelimitSubnetLenIPv4, 0, netutil.IPv4BitLen); err != nil {
		return fmt.Errorf("ratelimit subnet len ipv4 is invalid: %w", err)
	}

	if err = checkInclusion(p.RatelimitSubnetLenIPv6, 0, netutil.IPv6BitLen); err != nil {
		return fmt.Errorf("ratelimit subnet len ipv6 is invalid: %w", err)
	}

	return nil
}
// checkInclusion returns an error if n is outside the inclusive range from minN
// to maxN. The lower bound is checked first.
func checkInclusion(n, minN, maxN int) (err error) {
	if n < minN {
		return fmt.Errorf("value %d less than min %d", n, minN)
	}

	if n > maxN {
		return fmt.Errorf("value %d greater than max %d", n, maxN)
	}

	return nil
}
// logConfigInfo logs an informational message for each notable setting that is
// enabled: TTL overrides, rate limiting, refusing ANY, bogus-nxdomain
// prefixes, and the upstream mode.
func (p *Proxy) logConfigInfo() {
	l := p.logger

	if p.CacheMinTTL > 0 || p.CacheMaxTTL > 0 {
		l.Info("cache ttl override is enabled", "min", p.CacheMinTTL, "max", p.CacheMaxTTL)
	}

	if p.Ratelimit > 0 {
		l.Info(
			"ratelimit is enabled",
			"rps", p.Ratelimit,
			"ipv4_subnet_mask_len", p.RatelimitSubnetLenIPv4,
			"ipv6_subnet_mask_len", p.RatelimitSubnetLenIPv6,
		)
	}

	if p.RefuseAny {
		l.Info("server will refuse requests of type any")
	}

	if n := len(p.BogusNXDomain); n > 0 {
		l.Info("bogus-nxdomain ip specified", "prefix_len", n)
	}

	if p.UpstreamMode != "" {
		l.Info("upstream mode is set", "mode", p.UpstreamMode)
	}
}
// validateListenAddrs returns an error if the addresses are not configured
// properly: no listen address at all, a TLS-based listener without a TLS
// configuration, or a DNSCrypt listener without the DNSCrypt resolver
// certificate and provider name.
func (p *Proxy) validateListenAddrs() (err error) {
	if !p.hasListenAddrs() {
		return errors.Error("no listen address specified")
	}

	err = p.validateTLSConfig()
	if err != nil {
		return fmt.Errorf("invalid tls configuration: %w", err)
	}

	if p.DNSCryptResolverCert == nil || p.DNSCryptProviderName == "" {
		// Check the length rather than nil-ness, so that an empty but non-nil
		// list, which has no addresses to listen on, isn't rejected.
		if len(p.DNSCryptTCPListenAddr) > 0 {
			return errors.Error("cannot create dnscrypt tcp listener without dnscrypt config")
		}

		if len(p.DNSCryptUDPListenAddr) > 0 {
			return errors.Error("cannot create dnscrypt udp listener without dnscrypt config")
		}
	}

	return nil
}
// validateTLSConfig returns an error if the TLS configuration is missing while
// at least one DNS-over-TLS, DNS-over-HTTPS, or DNS-over-QUIC listen address is
// configured.
func (p *Proxy) validateTLSConfig() (err error) {
	if p.TLSConfig != nil {
		return nil
	}

	// Check the lengths rather than nil-ness, so that an empty but non-nil
	// list, which has no addresses to listen on, doesn't require a TLS
	// configuration.
	if len(p.TLSListenAddr) > 0 {
		return errors.Error("tls listener configuration not found")
	}

	if len(p.HTTPSListenAddr) > 0 {
		return errors.Error("https listener configuration not found")
	}

	if len(p.QUICListenAddr) > 0 {
		return errors.Error("quic listener configuration not found")
	}

	return nil
}
// hasListenAddrs reports whether any of the listen address lists is set, i.e.
// is not nil.
func (p *Proxy) hasListenAddrs() bool {
	switch {
	case p.UDPListenAddr != nil,
		p.TCPListenAddr != nil,
		p.TLSListenAddr != nil,
		p.HTTPSListenAddr != nil,
		p.QUICListenAddr != nil,
		p.DNSCryptUDPListenAddr != nil,
		p.DNSCryptTCPListenAddr != nil:
		return true
	default:
		return false
	}
}
07070100000059000081A4000000000000000000000001679A649F000000AE000000000000000000000000000000000000002500000000dnsproxy-0.75.0/proxy/constructor.gopackage proxy
import (
"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
)
// MessageConstructor creates DNS messages. It's an alias for
// [dnsmsg.MessageConstructor]; since that package is internal, this alias is
// how code outside of this module can refer to the type.
type MessageConstructor = dnsmsg.MessageConstructor
0707010000005A000081A4000000000000000000000001679A649F000025E5000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/dns64.gopackage proxy
import (
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
const (
	// NAT64PrefixLength is the number of bytes of an IPv6 address that come
	// before an IPv4 address embedded into it, i.e. the length of a /96 NAT64
	// prefix in bytes.
	NAT64PrefixLength = net.IPv6len - net.IPv4len

	// maxNAT64PrefixBitLen is the longest allowed NAT64 prefix, in bits.
	//
	// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.2.
	maxNAT64PrefixBitLen = 96

	// maxDNS64SynTTL is the upper bound, in seconds, for the TTL of
	// synthesized DNS64 responses when no SOA record is available.
	//
	// When the negative response to the AAAA query carried no SOA RR, the
	// DNS64 SHOULD use the smaller of the original A RR's TTL and 600 seconds.
	//
	// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.1.7.
	maxDNS64SynTTL uint32 = 600
)
// setupDNS64 prepares the NAT64 prefixes used by DNS64. It does nothing when
// DNS64 is disabled. When DNS64 is enabled but no prefixes are configured, the
// Well-Known Prefix is used, as Section 5.2 of RFC 6147 prescribes. Otherwise,
// only the configured prefixes are used, so the Well-Known Prefix is included
// only if it's configured explicitly. Every configured prefix must be an IPv6
// prefix no longer than 96 bits, and is stored masked. The first prefix is the
// one used to synthesize AAAA records.
func (p *Proxy) setupDNS64() (err error) {
	if !p.Config.UseDNS64 {
		return nil
	}

	prefs := p.Config.DNS64Prefs
	if len(prefs) == 0 {
		p.dns64Prefs = netutil.SliceSubnetSet{dns64WellKnownPref}

		return nil
	}

	for idx, pref := range prefs {
		switch {
		case !pref.Addr().Is6():
			return fmt.Errorf("prefix at index %d: %q is not an IPv6 prefix", idx, pref)
		case pref.Bits() > maxNAT64PrefixBitLen:
			return fmt.Errorf("prefix at index %d: %q is too long for DNS64", idx, pref)
		}

		p.dns64Prefs = append(p.dns64Prefs, pref.Masked())
	}

	return nil
}
// checkDNS64 decides whether DNS64 synthesis is needed for req given its
// response resp. It returns the A request to send upstream, or nil if DNS64
// shouldn't be performed. For successful responses, it also removes AAAA
// records within the NAT64 prefixes from resp's answer section. req and resp
// must not be nil, and req must have at least one question.
//
// See https://datatracker.ietf.org/doc/html/rfc6147.
func (p *Proxy) checkDNS64(req, resp *dns.Msg) (dns64Req *dns.Msg) {
	if len(p.dns64Prefs) == 0 {
		return nil
	}

	// DNS64 operation for classes other than IN is undefined, and a DNS64
	// MUST behave as though no DNS64 function is configured.
	if q := req.Question[0]; q.Qtype != dns.TypeAAAA || q.Qclass != dns.ClassINET {
		return nil
	}

	// A result with RCODE=3 (Name Error) is handled according to normal DNS
	// operation, which is normally to return the error to the client.
	if resp.Rcode == dns.RcodeNameError {
		return nil
	}

	// Any RCODE other than 0 and 3 is treated as though the RCODE were 0 and
	// the answer section were empty, so only successful responses are
	// filtered. If the filtered answer still has AAAA records outside the
	// excluded ranges, the response is returned with only those records.
	if resp.Rcode == dns.RcodeSuccess {
		filtered, hasAnswers := p.filterNAT64Answers(resp.Answer)
		resp.Answer = filtered
		if hasAnswers {
			return nil
		}
	}

	dns64Req = req.Copy()
	dns64Req.Id = dns.Id()
	dns64Req.Question[0].Qtype = dns.TypeA

	return dns64Req
}
// filterNAT64Answers returns a copy of rrs without the AAAA records whose
// addresses fall within one of the configured DNS64 prefixes. AAAA records
// whose addresses can't be converted are logged and dropped too, and records
// of any other type are kept. hasAnswers is true if the result contains at
// least one AAAA, CNAME, or DNAME record.
func (p *Proxy) filterNAT64Answers(rrs []dns.RR) (filtered []dns.RR, hasAnswers bool) {
	filtered = make([]dns.RR, 0, len(rrs))
	for _, rr := range rrs {
		keep, isAnswer := true, false

		switch typed := rr.(type) {
		case *dns.AAAA:
			addr, err := netutil.IPToAddrNoMapped(typed.AAAA)
			switch {
			case err != nil:
				p.logger.Error("bad aaaa record", slogutil.KeyError, err)
				keep = false
			case p.dns64Prefs.Contains(addr):
				// The address is within an excluded range, drop the record.
				keep = false
			default:
				isAnswer = true
			}
		case *dns.CNAME, *dns.DNAME:
			// Normally a CNAME or DNAME chain is followed until the first
			// terminating A or AAAA record. AdGuard Home doesn't follow any of
			// these chains except the dnsrewrite-defined ones, so such records
			// are simply treated as passable answers.
			isAnswer = true
		}

		if keep {
			filtered = append(filtered, rr)
		}

		hasAnswers = hasAnswers || isAnswer
	}

	return filtered, hasAnswers
}
// synthDNS64 turns origResp into a DNS64 response, converting the A records
// of resp into AAAA records and taking resp's authority and additional
// sections. It reports whether origResp was modified. origResp is left intact
// if resp has no answers or if any of its A records is invalid.
func (p *Proxy) synthDNS64(origReq, origResp, resp *dns.Msg) (ok bool) {
	if len(resp.Answer) == 0 {
		// With an empty answer, the DNS64 responds to the original querying
		// client with the answer it received to the original (initiator's)
		// query.
		return false
	}

	// The TTL is the minimum of the original A RR's TTL and the TTL of the SOA
	// RR for the queried domain, or of [maxDNS64SynTTL] when the original
	// response carries no such SOA record.
	//
	// NOTE(review): the SOA owner name must be exactly, case-sensitively equal
	// to the question name, so an SOA for an enclosing zone apex isn't used.
	// Confirm that this is intended.
	soaTTL := maxDNS64SynTTL
	for _, rr := range origResp.Ns {
		hdr := rr.Header()
		if hdr.Rrtype != dns.TypeSOA || hdr.Name != origReq.Question[0].Name {
			continue
		}

		soaTTL = hdr.Ttl

		break
	}

	synthesized := make([]dns.RR, 0, len(resp.Answer))
	for _, ans := range resp.Answer {
		aaaa := p.synthRR(ans, soaTTL)
		if aaaa == nil {
			// synthRR has already logged the error.
			return false
		}

		synthesized = append(synthesized, aaaa)
	}

	origResp.Answer, origResp.Ns, origResp.Extra = synthesized, resp.Ns, resp.Extra

	return true
}
// dns64WellKnownPref is the Well-Known Prefix, 64:ff9b::/96, reserved for the
// algorithmic mapping used by DNS64. It's the default when no prefixes are
// configured.
//
// See https://datatracker.ietf.org/doc/html/rfc6052#section-2.1.
var dns64WellKnownPref = netip.PrefixFrom(netip.MustParseAddr("64:ff9b::"), 96)
// shouldStripDNS64 reports whether req is a PTR request for a reversed address
// that lies within one of the configured DNS64 prefixes or within the
// Well-Known Prefix. It always returns false when DNS64 isn't enabled.
//
// Any Pref64::/n used at the site must be matched, not only the locally
// configured one, since end clients could ask for a PTR record matching an
// address received through a different (site-provided) DNS64.
//
// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.3.1.
func (p *Proxy) shouldStripDNS64(req *dns.Msg) (ok bool) {
	if len(p.dns64Prefs) == 0 {
		return false
	}

	q := req.Question[0]
	if q.Qtype != dns.TypePTR {
		return false
	}

	ip, err := netutil.IPFromReversedAddr(q.Name)
	if err != nil {
		p.logger.Debug("failed to parse ip from ptr request", slogutil.KeyError, err)

		return false
	}

	if p.dns64Prefs.Contains(ip) {
		p.logger.Debug("the ip is within dns64 custom prefix set", "ip", ip)

		return true
	}

	if dns64WellKnownPref.Contains(ip) {
		p.logger.Debug("the ip is within dns64 well-known prefix", "ip", ip)

		return true
	}

	return false
}
// mapDNS64 maps addr to an IPv6 address using the first configured DNS64
// prefix, following the IPv4-embedded IPv6 address format of RFC 6052,
// Section 2.2. addr must be a valid IPv4 address. It panics if there are no
// configured DNS64 prefixes, since synthesis should not be performed unless
// the DNS64 function is enabled.
//
// For the prefix lengths defined by RFC 6052 (32, 40, 48, 56, and 64 bits),
// the IPv4 address is split around the reserved "u" octet (bits 64 through
// 71), which stays zero, and the suffix is zero. For /96 and any other length,
// the IPv4 address takes the last four bytes.
func (p *Proxy) mapDNS64(addr netip.Addr) (mapped net.IP) {
	// Don't mask the prefix here since it should have already been masked on
	// initialization stage.
	pref := p.dns64Prefs[0]
	prefData := pref.Addr().As16()
	addrData := addr.As4()

	mapped = make(net.IP, net.IPv6len)

	switch bits := pref.Bits(); bits {
	case 32, 40, 48, 56, 64:
		// uOctetIdx is the index of the byte holding bits 64 through 71 of the
		// address, which must be zero.
		const uOctetIdx = 8

		prefLen := bits / 8
		copy(mapped[:prefLen], prefData[:prefLen])

		// Put as many leading IPv4 bytes as fit between the prefix and the
		// "u" octet, and the remaining ones right after it.
		n := copy(mapped[prefLen:uOctetIdx], addrData[:])
		copy(mapped[uOctetIdx+1:], addrData[n:])
	default:
		copy(mapped[:NAT64PrefixLength], prefData[:])
		copy(mapped[NAT64PrefixLength:], addrData[:])
	}

	return mapped
}
// synthRR converts rr into a DNS64-synthesized record per RFC 6147. Records
// other than A are returned unchanged. An A record becomes an AAAA record with
// the mapped address and with a TTL that is the smaller of the record's own
// TTL and soaTTL. nil is returned, and the error logged, if the A record's
// address isn't a valid IPv4 address.
func (p *Proxy) synthRR(rr dns.RR, soaTTL uint32) (result dns.RR) {
	a, isA := rr.(*dns.A)
	if !isA {
		return rr
	}

	ipv4, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
	if err != nil {
		p.logger.Error("bad a record", slogutil.KeyError, err)

		return nil
	}

	hdr := dns.RR_Header{
		Name:   a.Hdr.Name,
		Rrtype: dns.TypeAAAA,
		Class:  a.Hdr.Class,
		Ttl:    min(a.Hdr.Ttl, soaTTL),
	}

	return &dns.AAAA{Hdr: hdr, AAAA: p.mapDNS64(ipv4)}
}
// performDNS64 performs DNS64 for origReq, if needed, by sending an A request
// to upstreams and synthesizing AAAA records into origResp. It returns the
// upstream that answered the A request if origResp was actually modified, and
// nil otherwise, including when origResp is nil or the exchange fails.
func (p *Proxy) performDNS64(
	origReq *dns.Msg,
	origResp *dns.Msg,
	upstreams []upstream.Upstream,
) (u upstream.Upstream) {
	if origResp == nil {
		return nil
	}

	dns64Req := p.checkDNS64(origReq, origResp)
	if dns64Req == nil {
		return nil
	}

	host := origReq.Question[0].Name
	p.logger.Debug("received an empty aaaa response, checking dns64", "host", host)

	dns64Resp, u, err := p.exchangeUpstreams(dns64Req, upstreams)
	if err != nil {
		p.logger.Error("dns64 request failed", slogutil.KeyError, err)

		return nil
	}

	if dns64Resp == nil || !p.synthDNS64(origReq, origResp, dns64Resp) {
		return nil
	}

	p.logger.Debug("synthesized aaaa response", "host", host)

	return u
}
0707010000005B000081A4000000000000000000000001679A649F00002784000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/dns64_test.gopackage proxy
import (
"context"
"net"
"net/netip"
"sync"
"testing"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ipv4OnlyFqdn is the test domain name for which the fake upstreams return
// only A records.
const ipv4OnlyFqdn = "ipv4.only."

// TestDNS64Race sends concurrent AAAA requests for an IPv4-only name, each over
// its own TCP connection, to exercise concurrent DNS64 synthesis.
func TestDNS64Race(t *testing.T) {
	ans := newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, net.ParseIP("1.2.3.4"))
	ups := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			resp = (&dns.Msg{}).SetReply(req)
			if req.Question[0].Qtype == dns.TypeA {
				resp.Answer = []dns.RR{dns.Copy(ans)}
			}

			return resp, nil
		},
		onAddress: func() (addr string) { return "fake.address" },
		onClose:   func() (err error) { return nil },
	}
	localUps := &fakeUpstream{
		onExchange: func(_ *dns.Msg) (_ *dns.Msg, _ error) { panic("not implemented") },
		onAddress:  func() (addr string) { return "fake.address" },
		onClose:    func() (err error) { return nil },
	}

	dnsProxy := mustNew(t, &Config{
		Logger:         slogutil.NewDiscardLogger(),
		UDPListenAddr:  []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:  []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		PrivateRDNSUpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{localUps},
		},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		UseDNS64:               true,
		UsePrivateRDNS:         true,
		// Valid NAT-64 prefix for 2001:67c:27e4:15::64 server.
		DNS64Prefs: []netip.Prefix{netip.MustParsePrefix("2001:67c:27e4:1064::/96")},
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	addr := dnsProxy.Addr(ProtoTCP).String()

	// Dial all the connections before starting any goroutine, so that a dial
	// failure can't leave goroutines blocked on syncCh forever. The [dns.Conn]
	// isn't safe for concurrent use despite the requirements from the
	// [net.Conn] documentation, so each goroutine gets its own one. Every
	// connection is closed when the test finishes, before the proxy shuts
	// down.
	conns := make([]*dns.Conn, 0, testMessagesCount)
	for range testMessagesCount {
		var conn *dns.Conn
		conn, err = dns.Dial("tcp", addr)
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, conn.Close)

		conns = append(conns, conn)
	}

	syncCh := make(chan struct{})

	// Send requests.
	g := &sync.WaitGroup{}
	g.Add(len(conns))
	for _, conn := range conns {
		go sendTestAAAAMessageAsync(conn, g, ipv4OnlyFqdn, syncCh)
	}

	close(syncCh)
	g.Wait()
}
// sendTestAAAAMessageAsync waits for syncCh to be closed, then sends an AAAA
// request for fqdn over conn and requires a successful response whose first
// answer is an AAAA record. It calls g.Done when it returns. Failures are
// reported through [testutil.PanicT], since it's meant to run in a separate
// goroutine.
func sendTestAAAAMessageAsync(conn *dns.Conn, g *sync.WaitGroup, fqdn string, syncCh chan struct{}) {
	pt := testutil.PanicT{}
	defer g.Done()

	req := (&dns.Msg{}).SetQuestion(fqdn, dns.TypeAAAA)

	// Wait for the signal so that all requests are sent at once.
	<-syncCh

	err := conn.WriteMsg(req)
	require.NoError(pt, err)

	res, err := conn.ReadMsg()
	require.NoError(pt, err)

	// The expected value goes first, so that failure messages are correct.
	require.Equal(pt, dns.RcodeSuccess, res.Rcode)
	require.NotEmpty(pt, res.Answer)
	require.IsType(pt, &dns.AAAA{}, res.Answer[0])
}
// newRR returns a class IN resource record named name with the type qtype, the
// TTL ttl, and the data val. val must be a [net.IP] for A and AAAA, a string
// for CNAME and PTR, and is ignored for SOA, whose fields are derived from
// name. Any other qtype, or a val of the wrong type, fails the test.
func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns.RR) {
	t.Helper()

	hdr := dns.RR_Header{
		Name:   name,
		Rrtype: qtype,
		Class:  dns.ClassINET,
		Ttl:    ttl,
	}

	switch qtype {
	case dns.TypeA:
		return &dns.A{Hdr: hdr, A: testutil.RequireTypeAssert[net.IP](t, val)}
	case dns.TypeAAAA:
		return &dns.AAAA{Hdr: hdr, AAAA: testutil.RequireTypeAssert[net.IP](t, val)}
	case dns.TypeCNAME:
		return &dns.CNAME{Hdr: hdr, Target: testutil.RequireTypeAssert[string](t, val)}
	case dns.TypeSOA:
		return &dns.SOA{
			Hdr:     hdr,
			Ns:      "ns." + name,
			Mbox:    "hostmaster." + name,
			Serial:  1,
			Refresh: 1,
			Retry:   1,
			Expire:  1,
			Minttl:  1,
		}
	case dns.TypePTR:
		return &dns.PTR{Hdr: hdr, Ptr: testutil.RequireTypeAssert[string](t, val)}
	default:
		t.Fatalf("unsupported qtype: %d", qtype)

		return nil
	}
}
// TestProxy_Resolve_dns64 checks the answers of a proxy with DNS64 and
// private reverse DNS enabled.  It covers these cases:
//
//   - A and AAAA responses that are passed through;
//   - AAAA responses synthesized from A records;
//   - AAAA records from the DNS64 prefix being filtered out;
//   - PTR requests for synthesized addresses being sent to the private
//     upstream.
func TestProxy_Resolve_dns64(t *testing.T) {
	const (
		ipv6Domain    = "ipv6.only."
		soaDomain     = "ipv4.soa."
		mappedDomain  = "filterable.ipv6."
		anotherDomain = "another.domain."
		pointedDomain = "local1234.ipv4."
		globDomain    = "real1234.ipv4."
	)

	someIPv4 := net.IP{1, 2, 3, 4}
	someIPv6 := net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
	mappedIPv6 := net.ParseIP("64:ff9b::102:304")

	ptr64Domain, err := netutil.IPToReversedAddr(mappedIPv6)
	require.NoError(t, err)
	ptr64Domain = dns.Fqdn(ptr64Domain)

	ptrGlobDomain, err := netutil.IPToReversedAddr(someIPv4)
	require.NoError(t, err)
	ptrGlobDomain = dns.Fqdn(ptrGlobDomain)

	localCliAddr := netip.MustParseAddrPort("192.168.1.1:1234")

	const (
		sectionAnswer = iota
		sectionAuthority
		sectionAdditional

		sectionsNum
	)

	// answerMap is a convenience alias for describing the upstream response for
	// a given question type.
	type answerMap = map[uint16][sectionsNum][]dns.RR

	pt := testutil.PanicT{}

	// newUps returns an upstream that replies to each request with the sections
	// from answers for the request's question type.  It fails via pt if answers
	// has no entry for that type.
	newUps := func(answers answerMap) (u upstream.Upstream) {
		return &fakeUpstream{
			onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
				q := req.Question[0]
				require.Contains(pt, answers, q.Qtype)

				answer := answers[q.Qtype]

				resp = (&dns.Msg{}).SetReply(req)
				resp.Answer = answer[sectionAnswer]
				resp.Ns = answer[sectionAuthority]
				resp.Extra = answer[sectionAdditional]

				return resp, nil
			},
			onAddress: func() (addr string) { return "fake.address" },
			onClose:   func() (err error) { return nil },
		}
	}

	localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain)
	localUps := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			// The expected value goes first, so that failure messages are
			// reported correctly.
			require.Equal(pt, ptr64Domain, req.Question[0].Name)

			resp = (&dns.Msg{}).SetReply(req)
			resp.Answer = []dns.RR{localRR}

			return resp, nil
		},
		onAddress: func() (addr string) { return "fake.local.address" },
		onClose:   func() (err error) { return nil },
	}

	testCases := []struct {
		name    string
		qname   string
		upsAns  answerMap
		wantAns []dns.RR
		qtype   uint16
	}{{
		name:  "simple_a",
		qname: ipv4OnlyFqdn,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {},
		},
		wantAns: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   ipv4OnlyFqdn,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			A: someIPv4,
		}},
		qtype: dns.TypeA,
	}, {
		name:  "simple_aaaa",
		qname: ipv6Domain,
		upsAns: answerMap{
			dns.TypeA: {},
			dns.TypeAAAA: {
				sectionAnswer: {newRR(t, ipv6Domain, dns.TypeAAAA, 3600, someIPv6)},
			},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   ipv6Domain,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			AAAA: someIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "actual_dns64",
		qname: ipv4OnlyFqdn,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   ipv4OnlyFqdn,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    maxDNS64SynTTL,
			},
			AAAA: mappedIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "actual_dns64_soattl",
		qname: soaDomain,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, soaDomain, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {
				sectionAuthority: {newRR(t, soaDomain, dns.TypeSOA, maxDNS64SynTTL+50, nil)},
			},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   soaDomain,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    maxDNS64SynTTL + 50,
			},
			AAAA: mappedIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "filtered",
		qname: mappedDomain,
		upsAns: answerMap{
			dns.TypeA: {},
			dns.TypeAAAA: {
				sectionAnswer: {
					newRR(t, mappedDomain, dns.TypeAAAA, 3600, net.ParseIP("64:ff9b::506:708")),
					newRR(t, mappedDomain, dns.TypeCNAME, 3600, anotherDomain),
				},
			},
		},
		wantAns: []dns.RR{&dns.CNAME{
			Hdr: dns.RR_Header{
				Name:   mappedDomain,
				Rrtype: dns.TypeCNAME,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Target: anotherDomain,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:   "ptr",
		qname:  ptr64Domain,
		upsAns: nil,
		wantAns: []dns.RR{&dns.PTR{
			Hdr: dns.RR_Header{
				Name:   ptr64Domain,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Ptr: pointedDomain,
		}},
		qtype: dns.TypePTR,
	}, {
		name:  "ptr_glob",
		qname: ptrGlobDomain,
		upsAns: answerMap{
			dns.TypePTR: {
				sectionAnswer: {newRR(t, ptrGlobDomain, dns.TypePTR, 3600, globDomain)},
			},
		},
		wantAns: []dns.RR{&dns.PTR{
			Hdr: dns.RR_Header{
				Name:   ptrGlobDomain,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Ptr: globDomain,
		}},
		qtype: dns.TypePTR,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			p := mustNew(t, &Config{
				Logger:        slogutil.NewDiscardLogger(),
				UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
				TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
				UpstreamConfig: &UpstreamConfig{
					Upstreams: []upstream.Upstream{newUps(tc.upsAns)},
				},
				PrivateRDNSUpstreamConfig: &UpstreamConfig{
					Upstreams: []upstream.Upstream{localUps},
				},
				TrustedProxies:         defaultTrustedProxies,
				RatelimitSubnetLenIPv4: 24,
				RatelimitSubnetLenIPv6: 64,
				CacheEnabled:           true,
				UseDNS64:               true,
				UsePrivateRDNS:         true,
				PrivateSubnets:         netutil.SubnetSetFunc(netutil.IsLocallyServed),
			})

			ctx := context.Background()

			// Use a subtest-local error instead of overwriting the parent
			// test's captured err.
			err := p.Start(ctx)
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

			dctx := &DNSContext{
				Req:  (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype),
				Addr: localCliAddr,
			}

			err = p.handleDNSRequest(dctx)
			require.NoError(t, err)

			res := dctx.Res
			require.NotNil(t, res)

			assert.Equal(t, tc.wantAns, res.Answer)
		})
	}
}
0707010000005C000081A4000000000000000000000001679A649F00001DDF000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/dnscontext.gopackage proxy
import (
"net"
"net/http"
"net/netip"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/ameshkov/dnscrypt/v2"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
// DNSContext represents the context of a DNS request message: the request and
// the response, the client's address, and the connection and protocol details
// needed to respond.
type DNSContext struct {
// Conn is the underlying client connection. It is nil if Proto is
// ProtoDNSCrypt, ProtoHTTPS, or ProtoQUIC.
Conn net.Conn
// QUICConnection is the QUIC session from which we got the query. For
// ProtoQUIC only.
QUICConnection quic.Connection
// QUICStream is the QUIC stream from which we got the query. For
// [ProtoQUIC] only.
QUICStream quic.Stream
// Upstream is the upstream that resolved the request. In case of cached
// response it's nil.
Upstream upstream.Upstream
// DNSCryptResponseWriter is the writer used to respond to a DNSCrypt query.
// For DNSCrypt only.
DNSCryptResponseWriter dnscrypt.ResponseWriter
// HTTPResponseWriter is the HTTP response writer. For DoH only.
HTTPResponseWriter http.ResponseWriter
// HTTPRequest is the HTTP request that carried the query. For DoH only.
HTTPRequest *http.Request
// ReqECS is the EDNS Client Subnet used in the request.
ReqECS *net.IPNet
// CustomUpstreamConfig is the upstreams configuration used only for current
// request. The Resolve method of Proxy uses it instead of the default
// servers if it's not nil.
CustomUpstreamConfig *CustomUpstreamConfig
// queryStatistics contains the DNS query statistics for both the upstream
// and fallback DNS servers. It is exposed via the QueryStatistics method.
queryStatistics *QueryStatistics
// Req is the request message.
Req *dns.Msg
// Res is the response message.
Res *dns.Msg
// Proto is the DNS protocol of the query.
Proto Proto
// RequestedPrivateRDNS is the subnet extracted from the ARPA domain of
// request's question if it's a PTR, SOA, or NS query for a private IP
// address. It can be a single-address subnet as well as a zero-length one.
RequestedPrivateRDNS netip.Prefix
// localIP is the local IP address of the UDP socket the query was received
// on. It is used when calling udpMakeOOBWithSrc.
localIP netip.Addr
// Addr is the address of the client.
Addr netip.AddrPort
// DoQVersion is the DoQ protocol version. It can (and should) be read from
// ALPN, but in the current version we also use the way DNS messages are
// encoded as a signal.
DoQVersion DoQVersion
// RequestID is an opaque numerical identifier of this request that is
// guaranteed to be unique across requests processed by a single Proxy
// instance. It is only assigned when the context is created with
// newDNSContext.
RequestID uint64
// udpSize is the UDP buffer size from request's EDNS0 RR if presented,
// or default otherwise. It is set lazily by calcFlagsAndSize; zero means
// it hasn't been calculated yet.
udpSize uint16
// IsPrivateClient is true if the client's address is considered private
// according to the configured private subnet set.
IsPrivateClient bool
// adBit is the authenticated data flag from the request. It is set by
// calcFlagsAndSize.
adBit bool
// hasEDNS0 reflects if the request has an EDNS0 OPT RR. It is set by
// calcFlagsAndSize.
hasEDNS0 bool
// doBit is the DNSSEC OK flag from request's EDNS0 RR if presented. It is
// set by calcFlagsAndSize.
doBit bool
}
// newDNSContext returns a new *DNSContext for req received from addr over
// proto.  Unlike a bare composite literal, it also assigns a fresh RequestID
// taken from the proxy's request counter.
//
// TODO(e.burkov): Consider creating DNSContext with this everywhere, to
// actually respect the contract of DNSContext.RequestID field.
func (p *Proxy) newDNSContext(proto Proto, req *dns.Msg, addr netip.AddrPort) (d *DNSContext) {
	d = &DNSContext{
		Req:   req,
		Addr:  addr,
		Proto: proto,
	}
	d.RequestID = p.counter.Add(1)

	return d
}
// QueryStatistics returns the DNS query statistics for both the upstream and
// fallback DNS servers.  The result is nil until a DNS lookup has been
// performed.
//
// Depending on the outcome of the resolution and on the upstream mode, the
// statistics contain:
//
//   - the DNS lookup duration for the main resolver, if the query was
//     successfully resolved;
//
//   - a single [UpstreamStatistics] entry with IsCached set to true, if the
//     query was retrieved from the cache;
//
//   - the DNS lookup durations or errors for each main upstream, if the
//     upstream mode is [UpstreamModeFastestAddr] and the query was
//     successfully resolved;
//
//   - the DNS lookup errors for each main upstream and the query duration for
//     the fallback resolver, if the query was resolved by the fallback
//     resolver;
//
//   - the DNS lookup errors for each main and fallback resolver, if the query
//     was not resolved at all.
func (dctx *DNSContext) QueryStatistics() (s *QueryStatistics) {
	s = dctx.queryStatistics

	return s
}
// calcFlagsAndSize lazily calculates the request-derived values required by
// the Resolve method:
//
//   - adBit, the AD flag of the request;
//   - hasEDNS0, whether the request has an EDNS0 OPT RR;
//   - doBit, the DO flag of that RR;
//   - udpSize, the advertised UDP buffer size.
//
// It does nothing if Req is nil or if the values have already been
// calculated, which is detected by a non-zero udpSize.
//
// Advertised sizes below [dns.MinMsgSize] are raised to it, since RFC 6891
// requires values lower than 512 to be treated as 512.  This also keeps
// udpSize non-zero for requests advertising a zero size, so that the
// memoization above works and the OPT RR added to the response by scrub
// doesn't advertise an invalid size.
func (dctx *DNSContext) calcFlagsAndSize() {
	if dctx.udpSize != 0 || dctx.Req == nil {
		return
	}

	dctx.adBit = dctx.Req.AuthenticatedData
	dctx.udpSize = defaultUDPBufSize
	if o := dctx.Req.IsEdns0(); o != nil {
		dctx.hasEDNS0 = true
		dctx.doBit = o.Do()
		dctx.udpSize = max(dns.MinMsgSize, o.UDPSize())
	}
}
// scrub prepares dctx.Res to be written to the client.  It does nothing if
// either Req or Res is nil.  Otherwise, it:
//
//   - adds an EDNS0 OPT RR to the response if the request had one but the
//     response doesn't, since RFC 6891 requires responders to include an OPT
//     RR in responses to requests that contain one (see
//     https://github.com/AdguardTeam/dnsproxy/issues/132);
//   - truncates the response to the size returned by dnsSize for the request
//     and its protocol;
//   - enables message compression, which some devices require.
func (dctx *DNSContext) scrub() {
	req, res := dctx.Req, dctx.Res
	if req == nil || res == nil {
		return
	}

	// Make sure hasEDNS0, udpSize, and doBit are calculated.
	dctx.calcFlagsAndSize()

	if dctx.hasEDNS0 && res.IsEdns0() == nil {
		res.SetEdns0(dctx.udpSize, dctx.doBit)
	}

	isUDP := dctx.Proto == ProtoUDP
	res.Truncate(int(dnsSize(isUDP, req)))

	res.Compress = true
}
// dnsSize returns the maximum size of a response to r.  When isUDP is false,
// it's [dns.MaxMsgSize], the largest possible DNS message.  Over UDP, it's the
// buffer size advertised in the OPT RR of r, but never less than
// [dns.MinMsgSize], which is also used when r has no OPT RR.
func dnsSize(isUDP bool, r *dns.Msg) (size uint16) {
	if !isUDP {
		return dns.MaxMsgSize
	}

	size = dns.MinMsgSize
	if o := r.IsEdns0(); o != nil {
		size = max(size, o.UDPSize())
	}

	return size
}
// DoQVersion enumerates the supported versions of the DNS-over-QUIC protocol.
type DoQVersion int

// Supported DoQVersion values.
const (
	// DoQv1Draft is the version used by the old DoQ drafts, which don't
	// prefix the DNS message with its 2-octet length.
	//
	// TODO(ameshkov): remove in the end of 2024.
	DoQv1Draft DoQVersion = 0

	// DoQv1 is DoQ v1.0 as specified by RFC 9250:
	// https://www.rfc-editor.org/rfc/rfc9250.html.
	DoQv1 DoQVersion = 1
)
// CustomUpstreamConfig contains upstreams configuration with an optional cache.
// Create it with NewCustomUpstreamConfig and set it as
// DNSContext.CustomUpstreamConfig to have the Resolve method of Proxy use it
// for a single request instead of the default servers.
type CustomUpstreamConfig struct {
// upstream is the upstream configuration. Close closes it unless it is nil.
upstream *UpstreamConfig
// cache is an optional cache for upstreams in the current configuration.
// It is disabled if nil.
//
// TODO(d.kolyshev): Move this cache to [UpstreamConfig].
cache *cache
}
// NewCustomUpstreamConfig returns a new custom upstream configuration that
// uses the upstreams from u.  If cacheEnabled is true, the configuration gets
// a cache of its own, created from cacheSize and enableEDNSClientSubnet with
// optimistic caching turned off.  Otherwise, the configuration has no cache.
func NewCustomUpstreamConfig(
	u *UpstreamConfig,
	cacheEnabled bool,
	cacheSize int,
	enableEDNSClientSubnet bool,
) (c *CustomUpstreamConfig) {
	c = &CustomUpstreamConfig{upstream: u}
	if !cacheEnabled {
		return c
	}

	// TODO(d.kolyshev): Support optimistic with newOptimisticResolver.
	c.cache = newCache(cacheSize, enableEDNSClientSubnet, false)

	return c
}
// Close closes the upstream configuration of c.  It returns nil if c has no
// upstream configuration.
func (c *CustomUpstreamConfig) Close() (err error) {
	if u := c.upstream; u != nil {
		return u.Close()
	}

	return nil
}
// ClearCache removes all items from the cache of c, both the plain ones and
// the ones stored with subnets.  It does nothing if c has no cache.
func (c *CustomUpstreamConfig) ClearCache() {
	if cc := c.cache; cc != nil {
		cc.clearItems()
		cc.clearItemsWithSubnet()
	}
}
0707010000005D000081A4000000000000000000000001679A649F00000219000000000000000000000000000000000000002000000000dnsproxy-0.75.0/proxy/errors.go//go:build !plan9
// +build !plan9
package proxy
import (
"syscall"
"github.com/AdguardTeam/golibs/errors"
)
// isEPIPE reports whether err is, or wraps, [syscall.EPIPE].  A nil err is
// never EPIPE.  syscall.EPIPE is defined on every OS except for Plan 9, which
// can be validated with:
//
//	$ for os in $(go tool dist list | cut -d / -f 1 | sort -u)
//	do
//		echo -n "$os"
//		env GOOS="$os" go doc syscall.EPIPE | grep -F -e EPIPE
//	done
//
// The Plan 9 implementation lives in ./errors_plan9.go.
func isEPIPE(err error) (ok bool) {
	ok = errors.Is(err, syscall.EPIPE)

	return ok
}
0707010000005E000081A4000000000000000000000001679A649F00000232000000000000000000000000000000000000002600000000dnsproxy-0.75.0/proxy/errors_plan9.go//go:build plan9
// +build plan9
package proxy
import "strings"
// isEPIPE checks if the underlying error is EPIPE.  Plan 9 relies on error
// strings instead of error codes.  I couldn't find the exact constant with the
// text returned by a write on a closed socket, but it seems to be "sys: write
// on closed pipe".  See Plan 9's "man 2 notify".
//
// A nil err is never EPIPE.  This matches the behavior of the implementation
// for other OSes and avoids calling Error on a nil error.
//
// We don't currently support Plan 9, so it's not critical, but when we do, this
// needs to be rechecked.
func isEPIPE(err error) (ok bool) {
	if err == nil {
		return false
	}

	return strings.Contains(err.Error(), "write on closed pipe")
}
0707010000005F000081A4000000000000000000000001679A649F000002D1000000000000000000000000000000000000002500000000dnsproxy-0.75.0/proxy/errors_test.go//go:build !plan9
// +build !plan9
package proxy
import (
"fmt"
"syscall"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
)
func TestIsEPIPE(t *testing.T) {
type testCase struct {
err error
name string
want bool
}
testCases := []testCase{{
name: "nil",
err: nil,
want: false,
}, {
name: "epipe",
err: syscall.EPIPE,
want: true,
}, {
name: "not_epipe",
err: errors.Error("test error"),
want: false,
}, {
name: "wrapped_epipe",
err: fmt.Errorf("test error: %w", syscall.EPIPE),
want: true,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := isEPIPE(tc.err)
assert.Equal(t, tc.want, got)
})
}
}
07070100000060000081A4000000000000000000000001679A649F00001011000000000000000000000000000000000000002200000000dnsproxy-0.75.0/proxy/exchange.gopackage proxy
import (
"fmt"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"gonum.org/v1/gonum/stat/sampleuv"
)
// exchangeUpstreams resolves req using the given upstreams according to the
// configured upstream mode.  It returns the DNS response, the upstream that
// successfully resolved the request, and the error if any.
//
// In the parallel mode, and in the fastest-address mode for A and AAAA
// questions, the work is delegated to the corresponding helpers.  Otherwise,
// the upstreams are tried one by one in a random order weighted by their
// recorded round-trip times, until one of them succeeds.
func (p *Proxy) exchangeUpstreams(
	req *dns.Msg,
	ups []upstream.Upstream,
) (resp *dns.Msg, u upstream.Upstream, err error) {
	switch p.UpstreamMode {
	case UpstreamModeParallel:
		return upstream.ExchangeParallel(ups, req)
	case UpstreamModeFastestAddr:
		switch req.Question[0].Qtype {
		case dns.TypeA, dns.TypeAAAA:
			return p.fastestAddr.ExchangeFastest(req, ups)
		default:
			// Go on to the load-balancing mode.
		}
	default:
		// Go on to the load-balancing mode.
	}

	if len(ups) == 0 {
		// Without this check the loop below wouldn't run, and the returned
		// error would wrap a nil error, producing a garbled message.
		return nil, nil, errors.Error("no upstreams to exchange request with")
	}

	if len(ups) == 1 {
		u = ups[0]
		resp, _, err = p.exchange(u, req, p.time)
		if err != nil {
			return nil, nil, err
		}

		// TODO(e.burkov): Consider updating the RTT of a single upstream.
		return resp, u, nil
	}

	w := sampleuv.NewWeighted(p.calcWeights(ups), p.randSrc)
	var errs []error
	for i, ok := w.Take(); ok; i, ok = w.Take() {
		u = ups[i]

		var elapsed time.Duration
		resp, elapsed, err = p.exchange(u, req, p.time)
		if err == nil {
			p.updateRTT(u.Address(), elapsed)

			return resp, u, nil
		}

		errs = append(errs, err)

		// TODO(e.burkov): Use the actual configured timeout or, perhaps, the
		// actual measured elapsed time.
		p.updateRTT(u.Address(), defaultTimeout)
	}

	err = fmt.Errorf("all upstreams failed to exchange request: %w", errors.Join(errs...))

	return nil, nil, err
}
// exchange sends req to u and returns the response, the time the exchange
// took as measured with c, and the error if any.  Failures are logged at the
// error level, successes at the debug level.
func (p *Proxy) exchange(
	u upstream.Upstream,
	req *dns.Msg,
	c clock,
) (resp *dns.Msg, dur time.Duration, err error) {
	start := c.Now()
	resp, err = u.Exchange(req)
	// Don't use [time.Since] because it uses [time.Now].
	dur = c.Now().Sub(start)

	addr, q := u.Address(), &req.Question[0]
	if err == nil {
		p.logger.Debug(
			"exchange successfully finished",
			"upstream", addr,
			"question", q,
			"duration", dur,
		)

		return resp, dur, nil
	}

	p.logger.Error(
		"exchange failed",
		"upstream", addr,
		"question", q,
		"duration", dur,
		slogutil.KeyError, err,
	)

	return resp, dur, err
}
// upstreamRTTStats accumulates the round-trip time statistics of a single
// upstream.
type upstreamRTTStats struct {
	// rttSum is the sum of all the round-trip times in microseconds.  float64
	// is used since it's capable of representing about 285 years in
	// microseconds.
	rttSum float64

	// reqNum is the number of requests to the upstream.  float64 is used to
	// avoid type conversions when the average is computed.
	reqNum float64
}

// update returns a copy of stats with rtt, truncated to whole microseconds,
// added to the sum and the request counter incremented by one.  stats itself
// is left unchanged.
func (stats upstreamRTTStats) update(rtt time.Duration) (updated upstreamRTTStats) {
	updated = stats
	updated.rttSum += float64(rtt.Microseconds())
	updated.reqNum++

	return updated
}
// calcWeights returns the load-balancing weights for ups, one for each
// upstream and in the same order.  An upstream without recorded round-trip
// times gets the default weight of 1.  Any other upstream gets the reciprocal
// of its average round-trip time in microseconds, so faster upstreams weigh
// more.
func (p *Proxy) calcWeights(ups []upstream.Upstream) (weights []float64) {
	p.rttLock.Lock()
	defer p.rttLock.Unlock()

	weights = make([]float64, len(ups))
	for i, u := range ups {
		stat := p.upstreamRTTStats[u.Address()]
		if stat.rttSum != 0 && stat.reqNum != 0 {
			weights[i] = 1 / (stat.rttSum / stat.reqNum)
		} else {
			// Use 1 as the default weight.
			weights[i] = 1
		}
	}

	return weights
}
// updateRTT records rtt as another round-trip time of the upstream with the
// given address.  It initializes the statistics map on first use.
func (p *Proxy) updateRTT(address string, rtt time.Duration) {
	p.rttLock.Lock()
	defer p.rttLock.Unlock()

	stats := p.upstreamRTTStats
	if stats == nil {
		stats = map[string]upstreamRTTStats{}
		p.upstreamRTTStats = stats
	}

	stats[address] = stats[address].update(rtt)
}
07070100000061000081A4000000000000000000000001679A649F00001A0F000000000000000000000000000000000000003000000000dnsproxy-0.75.0/proxy/exchange_internal_test.gopackage proxy
import (
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/rand"
)
// fakeClock is a [clock] implementation that delegates to its onNow function.
type fakeClock struct {
	onNow func() (now time.Time)
}

// type check
var _ clock = (*fakeClock)(nil)

// Now implements the [clock] interface for *fakeClock by calling c.onNow.
func (c *fakeClock) Now() (now time.Time) {
	now = c.onNow()

	return now
}
// newUpstreamWithErrorRate returns an [upstream.Upstream] with the given name
// as its address.  It fails with [assert.AnError] on every rate-th request and
// answers all other requests with an empty reply.  rate must not be zero,
// since it's used as a divisor.  The returned upstream isn't safe for
// concurrent use, and its Close method panics.
func newUpstreamWithErrorRate(rate uint, name string) (u upstream.Upstream) {
	var reqCount uint
	exchange := func(req *dns.Msg) (resp *dns.Msg, err error) {
		reqCount++
		if reqCount%rate != 0 {
			return (&dns.Msg{}).SetReply(req), nil
		}

		return nil, assert.AnError
	}

	return &fakeUpstream{
		onExchange: exchange,
		onAddress:  func() (addr string) { return name },
		onClose:    func() (_ error) { panic("not implemented") },
	}
}
// measuredUpstream wraps an [upstream.Upstream] and counts the requests
// passed to its Exchange method, keyed by the upstream's address.
type measuredUpstream struct {
	// Upstream is the wrapped upstream.  It's embedded so that only Exchange
	// needs to be overridden.
	upstream.Upstream

	// stats maps upstream addresses to the number of exchanges.  It may be
	// shared between several measuredUpstreams and isn't guarded, so they
	// aren't safe for concurrent use.
	stats map[string]int64
}

// type check
var _ upstream.Upstream = measuredUpstream{}

// Exchange implements the [upstream.Upstream] interface for measuredUpstream.
// It increments the counter for the upstream's address and then delegates to
// the wrapped upstream.
func (u measuredUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	addr := u.Address()
	u.stats[addr]++

	return u.Upstream.Exchange(req)
}
// TestProxy_Exchange_loadBalance checks how requests are distributed among
// upstreams in the load-balancing mode, depending on their round-trip times
// and errors. The expected counts depend on the seeded random source, on the
// order of the test cases (which share that source), and on the exact
// sequence of clock calls, so they must be recalculated if any of those
// change.
func TestProxy_Exchange_loadBalance(t *testing.T) {
// Make the test deterministic.
randSrc := rand.NewSource(42)
const (
// testRTT is the base duration from which the simulated round-trip times
// of the upstreams are derived.
testRTT = 1 * time.Second
// requestsNum is the number of requests resolved in each test case.
requestsNum = 10_000
)
// zeroingClock returns the value of currentNow and sets it back to
// zeroTime, so that all the calls since the second one return the same zero
// value until currentNow is modified elsewhere.
zeroTime := time.Unix(0, 0)
currentNow := zeroTime
zeroingClock := &fakeClock{
onNow: func() (now time.Time) {
now, currentNow = currentNow, zeroTime
return now
},
}
// constClock returns the value of currentNow and advances it by
// testRTT/50, so that every exchange with an upstream that doesn't modify
// currentNow itself is measured as taking exactly testRTT/50.
constClock := &fakeClock{
onNow: func() (now time.Time) {
now, currentNow = currentNow, currentNow.Add(testRTT/50)
return now
},
}
// The upstreams below simulate their latency by setting currentNow to
// zeroTime plus their round-trip time, which zeroingClock then reports as
// the elapsed time of the exchange.
fastUps := &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
currentNow = zeroTime.Add(testRTT / 100)
return (&dns.Msg{}).SetReply(req), nil
},
onAddress: func() (addr string) { return "fast" },
onClose: func() (_ error) { panic("not implemented") },
}
slowerUps := &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
currentNow = zeroTime.Add(testRTT / 10)
return (&dns.Msg{}).SetReply(req), nil
},
onAddress: func() (addr string) { return "slower" },
onClose: func() (_ error) { panic("not implemented") },
}
slowestUps := &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
currentNow = zeroTime.Add(testRTT / 2)
return (&dns.Msg{}).SetReply(req), nil
},
onAddress: func() (addr string) { return "slowest" },
onClose: func() (_ error) { panic("not implemented") },
}
// err1Ups and err2Ups fail on every request.
err1Ups := &fakeUpstream{
onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError },
onAddress: func() (addr string) { return "error1" },
onClose: func() (_ error) { panic("not implemented") },
}
err2Ups := &fakeUpstream{
onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError },
onAddress: func() (addr string) { return "error2" },
onClose: func() (_ error) { panic("not implemented") },
}
singleError := &sync.Once{}
// fastestUps responds with an error on the first request.
fastestUps := &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
singleError.Do(func() { err = assert.AnError })
currentNow = zeroTime.Add(testRTT / 200)
return (&dns.Msg{}).SetReply(req), err
},
onAddress: func() (addr string) { return "fastest" },
onClose: func() (_ error) { panic("not implemented") },
}
// These upstreams fail on every 200th, 100th, and 50th request
// respectively.
each200 := newUpstreamWithErrorRate(200, "each_200")
each100 := newUpstreamWithErrorRate(100, "each_100")
each50 := newUpstreamWithErrorRate(50, "each_50")
testCases := []struct {
wantStat map[string]int64
clock clock
name string
servers []upstream.Upstream
}{{
wantStat: map[string]int64{
fastUps.Address(): 8917,
slowerUps.Address(): 911,
slowestUps.Address(): 172,
},
clock: zeroingClock,
name: "all_good",
servers: []upstream.Upstream{slowestUps, slowerUps, fastUps},
}, {
wantStat: map[string]int64{
fastUps.Address(): 9081,
slowerUps.Address(): 919,
err1Ups.Address(): 7,
},
clock: zeroingClock,
name: "one_bad",
servers: []upstream.Upstream{fastUps, err1Ups, slowerUps},
}, {
// Every request tries both failing upstreams.
wantStat: map[string]int64{
err1Ups.Address(): requestsNum,
err2Ups.Address(): requestsNum,
},
clock: zeroingClock,
name: "all_bad",
servers: []upstream.Upstream{err2Ups, err1Ups},
}, {
wantStat: map[string]int64{
fastUps.Address(): 7803,
slowerUps.Address(): 833,
fastestUps.Address(): 1365,
},
clock: zeroingClock,
name: "error_once",
servers: []upstream.Upstream{fastUps, slowerUps, fastestUps},
}, {
wantStat: map[string]int64{
each200.Address(): 5316,
each100.Address(): 3090,
each50.Address(): 1683,
},
clock: constClock,
name: "error_each_nth",
servers: []upstream.Upstream{each200, each100, each50},
}}
req := newTestMessage()
cli := netip.AddrPortFrom(netutil.IPv4Localhost(), 1234)
for _, tc := range testCases {
// Wrap the servers so that all of them count their exchanges into the
// same stats map.
ups := []upstream.Upstream{}
stats := map[string]int64{}
for _, s := range tc.servers {
ups = append(ups, measuredUpstream{
Upstream: s,
stats: stats,
})
}
// Each test case gets its own proxy, so RTT statistics aren't shared
// between the cases, but randSrc is.
p := mustNew(t, &Config{
Logger: slogutil.NewDiscardLogger(),
UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
UpstreamConfig: &UpstreamConfig{
Upstreams: ups,
},
TrustedProxies: defaultTrustedProxies,
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 64,
})
p.time = tc.clock
p.randSrc = randSrc
wantStat := tc.wantStat
t.Run(tc.name, func(t *testing.T) {
for range requestsNum {
_ = p.Resolve(&DNSContext{Req: req, Addr: cli})
}
assert.Equal(t, wantStat, stats)
})
}
}
07070100000062000081A4000000000000000000000001679A649F000007D8000000000000000000000000000000000000002600000000dnsproxy-0.75.0/proxy/handler_test.gopackage proxy
import (
"context"
"net"
"sync"
"testing"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFilteringHandler checks that a custom RequestHandler can either pass a
// request on to the default Resolve method or answer it by itself.
func TestFilteringHandler(t *testing.T) {
	// Initialize the test middleware state.  blockResponse is written by the
	// test goroutine and read by the proxy's handler, so it's guarded by m.
	m := &sync.RWMutex{}
	blockResponse := false

	// Prepare the proxy server.
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		RequestHandler: func(p *Proxy, d *DNSContext) error {
			// Only read the flag, and only under the read lock.  Don't hold
			// the lock while resolving, since that involves a network
			// round-trip.
			m.RLock()
			block := blockResponse
			m.RUnlock()

			if !block {
				// Use the default Resolve method if response is not blocked.
				return p.Resolve(d)
			}

			resp := &dns.Msg{}
			resp.SetRcode(d.Req, dns.RcodeNotImplemented)
			resp.RecursionAvailable = true

			// Set the response right away.
			d.Res = resp

			return nil
		},
	})

	// Start listening.
	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-UDP client connection.
	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}

	// Send the first message, which isn't blocked.
	req := newTestMessage()
	r, _, err := client.Exchange(req, addr.String())
	require.NoError(t, err)
	requireResponse(t, req, r)

	// Now send the second one and make sure it is blocked.
	m.Lock()
	blockResponse = true
	m.Unlock()

	r, _, err = client.Exchange(req, addr.String())
	require.NoError(t, err)
	assert.Equal(t, dns.RcodeNotImplemented, r.Rcode)
}
07070100000063000081A4000000000000000000000001679A649F00000907000000000000000000000000000000000000002100000000dnsproxy-0.75.0/proxy/helpers.gopackage proxy
import (
"net"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// ecsFromMsg returns the subnet and the source scope prefix length from the
// first EDNS Client Subnet option of m whose address family is supported,
// i.e. 1 for IPv4 or 2 for IPv6.  The returned IP isn't masked.  If
// SourceNetmask is too large for the family, the mask is nil, since
// [net.CIDRMask] returns nil in that case.  It returns nil and 0 if m has no
// OPT RR or no such option.
func ecsFromMsg(m *dns.Msg) (subnet *net.IPNet, scope int) {
	opt := m.IsEdns0()
	if opt == nil {
		return nil, 0
	}

	for _, o := range opt.Option {
		sn, ok := o.(*dns.EDNS0_SUBNET)
		if !ok {
			continue
		}

		var ip net.IP
		var bits int
		switch sn.Family {
		case 1:
			ip, bits = sn.Address.To4(), netutil.IPv4BitLen
		case 2:
			ip, bits = sn.Address, netutil.IPv6BitLen
		default:
			continue
		}

		mask := net.CIDRMask(int(sn.SourceNetmask), bits)

		return &net.IPNet{IP: ip, Mask: mask}, int(sn.SourceScope)
	}

	return nil, 0
}
// setECS adds an EDNS Client Subnet option to m and returns the subnet it
// describes.  The option is built from ip masked to the default prefix length
// for its family, /24 for IPv4 and /56 for IPv6, with the given source scope.
// The option is appended to the existing OPT RR of m, if any.  Otherwise, a new
// OPT RR advertising a UDP size of 4096 is appended to the additional section.
func setECS(m *dns.Msg, ip net.IP, scope uint8) (subnet *net.IPNet) {
	const (
		// defaultECSv4 is the default length of network mask for IPv4 address
		// in ECS option.
		defaultECSv4 = 24

		// defaultECSv6 is the default length of network mask for IPv6 address
		// in ECS.  The size of 7 octets is chosen as a reasonable minimum since
		// at least Google's public DNS refuses requests containing the options
		// with longer network masks.
		defaultECSv6 = 56
	)

	var family uint16
	var ones, bits int
	if ip4 := ip.To4(); ip4 != nil {
		family, ones, bits, ip = 1, defaultECSv4, netutil.IPv4BitLen, ip4
	} else {
		// Assume the IP address has already been validated.
		family, ones, bits = 2, defaultECSv6, netutil.IPv6BitLen
	}

	mask := net.CIDRMask(ones, bits)
	subnet = &net.IPNet{
		IP:   ip.Mask(mask),
		Mask: mask,
	}

	e := &dns.EDNS0_SUBNET{
		Code:          dns.EDNS0SUBNET,
		Family:        family,
		SourceNetmask: uint8(ones),
		SourceScope:   scope,
		Address:       subnet.IP,
	}

	// Reuse an existing OPT RR, since servers may return FORMERR if they meet
	// several OPT RRs.
	if opt := m.IsEdns0(); opt != nil {
		opt.Option = append(opt.Option, e)

		return subnet
	}

	o := &dns.OPT{
		Hdr: dns.RR_Header{
			Name:   ".",
			Rrtype: dns.TypeOPT,
		},
		Option: []dns.EDNS0{e},
	}
	o.SetUDPSize(4096)
	m.Extra = append(m.Extra, o)

	return subnet
}
07070100000064000081A4000000000000000000000001679A649F0000095C000000000000000000000000000000000000002000000000dnsproxy-0.75.0/proxy/lookup.gopackage proxy
import (
"context"
"net/netip"
"slices"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// lookupResult is the outcome of a single lookupIPAddr query, sent back to
// LookupNetIP over a channel.
type lookupResult struct {
// resp is the response message produced by p.Resolve.
resp *dns.Msg
// err is the error returned by p.Resolve, if any.
err error
}
// lookupIPAddr resolves the specified host IP addresses. It is intended to be
// used as a goroutine.
func (p *Proxy) lookupIPAddr(
ctx context.Context,
host string,
qtype uint16,
ch chan *lookupResult,
) {
defer slogutil.RecoverAndLog(ctx, p.logger)
req := (&dns.Msg{}).SetQuestion(host, qtype)
// TODO(d.kolyshev): Investigate why the client address is not defined.
d := p.newDNSContext(ProtoUDP, req, netip.AddrPort{})
err := p.Resolve(d)
ch <- &lookupResult{
resp: d.Res,
err: err,
}
}
// ErrEmptyHost is returned by [Proxy.LookupNetIP] when the host is empty and
// so can't be resolved.
const ErrEmptyHost errors.Error = "host is empty"
// type check
var _ upstream.Resolver = (*Proxy)(nil)

// LookupNetIP implements the [upstream.Resolver] interface for *Proxy.  It
// resolves the IP addresses of host by sending two DNS queries, A and AAAA,
// in parallel and combining the addresses from both answers.  The network
// argument is ignored.  The addresses are sorted to prefer IPv6 or IPv4
// according to Config.PreferIPv6.  An error is returned only if no addresses
// were found and at least one of the queries failed.
func (p *Proxy) LookupNetIP(
	ctx context.Context,
	_ string,
	host string,
) (addrs []netip.Addr, err error) {
	if host == "" {
		return nil, ErrEmptyHost
	}

	host = dns.Fqdn(host)

	ch := make(chan *lookupResult)
	go p.lookupIPAddr(ctx, host, dns.TypeA, ch)
	go p.lookupIPAddr(ctx, host, dns.TypeAAAA, ch)

	var errs []error
	for range 2 {
		result := <-ch
		if result.err != nil {
			errs = append(errs, result.err)

			continue
		}

		// Don't dereference a missing response from an otherwise successful
		// resolution.
		if result.resp != nil {
			addrs = appendAnswerAddrs(addrs, result.resp.Answer)
		}
	}

	if len(addrs) == 0 && len(errs) != 0 {
		return addrs, errors.Join(errs...)
	}

	if p.Config.PreferIPv6 {
		slices.SortStableFunc(addrs, netutil.PreferIPv6)
	} else {
		slices.SortStableFunc(addrs, netutil.PreferIPv4)
	}

	return addrs, nil
}
// appendAnswerAddrs appends the IP addresses found in the answer RRs ans to
// addrs and returns the resulting slice.  RRs for which [proxyutil.IPFromRR]
// returns an invalid address are skipped.
func appendAnswerAddrs(addrs []netip.Addr, ans []dns.RR) (res []netip.Addr) {
	res = addrs
	for _, rr := range ans {
		if ip := proxyutil.IPFromRR(rr); ip.IsValid() {
			res = append(res, ip)
		}
	}

	return res
}
07070100000065000081A4000000000000000000000001679A649F0000046F000000000000000000000000000000000000002500000000dnsproxy-0.75.0/proxy/lookup_test.gopackage proxy
import (
"context"
"net/netip"
"testing"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestLookupNetIP resolves a well-known host through a real public upstream
// and checks the returned addresses.
func TestLookupNetIP(t *testing.T) {
	// Use AdGuard DNS here.
	dnsUpstream, err := upstream.AddressToUpstream(
		"94.140.14.14",
		&upstream.Options{
			Logger:  slogutil.NewDiscardLogger(),
			Timeout: defaultTimeout,
		},
	)
	require.NoError(t, err)

	// Release the upstream's resources, since the proxy isn't started and
	// so won't close its upstreams on shutdown.
	t.Cleanup(func() { assert.NoError(t, dnsUpstream.Close()) })

	conf := &Config{
		Logger: slogutil.NewDiscardLogger(),
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{dnsUpstream},
		},
	}

	p, err := New(conf)
	require.NoError(t, err)

	// Now let's try doing some lookups.
	addrs, err := p.LookupNetIP(context.Background(), "", "dns.google")
	require.NoError(t, err)
	require.NotEmpty(t, addrs)

	assert.Contains(t, addrs, netip.MustParseAddr("8.8.8.8"))
	assert.Contains(t, addrs, netip.MustParseAddr("8.8.4.4"))

	// IPv6 addresses are only checked when they're returned.
	if len(addrs) > 2 {
		assert.Contains(t, addrs, netip.MustParseAddr("2001:4860:4860::8888"))
		assert.Contains(t, addrs, netip.MustParseAddr("2001:4860:4860::8844"))
	}
}
07070100000066000081A4000000000000000000000001679A649F00000704000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/proxy/optimisticresolver.gopackage proxy
import (
"context"
"encoding/hex"
"log/slog"
"sync"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// cachingResolver is a DNS resolver that can also store the responses it gets
// in a cache.
type cachingResolver interface {
	// cacheResp stores the response from dctx in the cache.
	cacheResp(dctx *DNSContext)

	// replyFromUpstream resolves the request from dctx. ok is true if the
	// request has been resolved successfully and the response may be cached.
	//
	// TODO(e.burkov): Find out when ok can be false with nil err.
	replyFromUpstream(dctx *DNSContext) (ok bool, err error)
}

// type check
var _ cachingResolver = (*Proxy)(nil)
// optimisticResolver eventually resolves the requests whose cached responses
// have expired, making sure requests with the same key aren't resolved
// simultaneously.
type optimisticResolver struct {
	// reqs holds the keys of the requests that are currently being resolved.
	reqs *sync.Map

	// cr resolves the requests and caches the responses.
	cr cachingResolver
}

// newOptimisticResolver returns a new resolver for the expired cached
// requests, which uses cr to resolve and cache them. cr must not be nil.
func newOptimisticResolver(cr cachingResolver) (s *optimisticResolver) {
	return &optimisticResolver{
		cr:   cr,
		reqs: new(sync.Map),
	}
}
// unit is a convenient alias for struct{}.
type unit = struct{}

// resolveOnce resolves the request from dctx using s.cr and caches the
// response if s.cr reports it may be cached. While a call for some key is in
// progress, other calls with an equal key return immediately without
// resolving anything. A resolving error is only logged at debug level with l.
// slogutil.RecoverAndLog is deferred with l. It is intended to be run in its
// own goroutine, so dctx must not be used anywhere else, as it isn't intended
// for concurrent use.
func (s *optimisticResolver) resolveOnce(dctx *DNSContext, key []byte, l *slog.Logger) {
	defer slogutil.RecoverAndLog(context.TODO(), l)

	reqKey := hex.EncodeToString(key)
	if _, inFlight := s.reqs.LoadOrStore(reqKey, unit{}); inFlight {
		return
	}
	defer s.reqs.Delete(reqKey)

	cacheable, err := s.cr.replyFromUpstream(dctx)
	if err != nil {
		l.Debug("resolving request for optimistic cache", slogutil.KeyError, err)
	}

	if !cacheable {
		return
	}

	s.cr.cacheResp(dctx)
}
07070100000067000081A4000000000000000000000001679A649F00000C0A000000000000000000000000000000000000003100000000dnsproxy-0.75.0/proxy/optimisticresolver_test.gopackage proxy
import (
"bytes"
"log/slog"
"sync"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/stretchr/testify/assert"
)
// testCachingResolver is a cachingResolver stub that delegates each method to
// the function field of the same purpose.
type testCachingResolver struct {
	onReplyFromUpstream func(dctx *DNSContext) (ok bool, err error)
	onCacheResp         func(dctx *DNSContext)
}

// type check
var _ cachingResolver = (*testCachingResolver)(nil)

// cacheResp implements the cachingResolver interface for *testCachingResolver.
func (r *testCachingResolver) cacheResp(dctx *DNSContext) {
	r.onCacheResp(dctx)
}

// replyFromUpstream implements the cachingResolver interface for
// *testCachingResolver.
func (r *testCachingResolver) replyFromUpstream(dctx *DNSContext) (ok bool, err error) {
	return r.onReplyFromUpstream(dctx)
}
func TestOptimisticResolver_ResolveOnce(t *testing.T) {
	var resolvedNum, cachedNum int

	started, release := make(chan unit), make(chan unit)

	s := newOptimisticResolver(&testCachingResolver{
		onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) {
			resolvedNum++

			return true, nil
		},
		onCacheResp: func(_ *DNSContext) {
			cachedNum++

			// Let the test start the competing goroutines while this one
			// still holds the key.
			started <- unit{}

			// Wait until all the competing goroutines have returned.
			<-release
		},
	})

	key := []byte{1, 2, 3}

	// Start the primary goroutine and wait until it reaches cacheResp.
	go s.resolveOnce(nil, key, slogutil.NewDiscardLogger())
	<-started

	const competitorsNum = 10

	wg := &sync.WaitGroup{}
	for range competitorsNum {
		wg.Add(1)
		go func() {
			defer wg.Done()

			s.resolveOnce(nil, key, slogutil.NewDiscardLogger())
		}()
	}
	wg.Wait()

	// Unblock the primary goroutine.
	release <- unit{}

	assert.Equal(t, 1, resolvedNum)
	assert.Equal(t, 1, cachedNum)
}
func TestOptimisticResolver_ResolveOnce_unsuccessful(t *testing.T) {
	key := []byte{1, 2, 3}

	t.Run("error", func(t *testing.T) {
		// TODO(d.kolyshev): Consider adding mock handler to golibs.
		buf := &bytes.Buffer{}
		handler := slog.NewTextHandler(buf, &slog.HandlerOptions{
			AddSource:   false,
			Level:       slog.LevelDebug,
			ReplaceAttr: nil,
		})

		const testErr errors.Error = "sample resolving error"

		var cached bool
		s := newOptimisticResolver(&testCachingResolver{
			onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) {
				return true, testErr
			},
			onCacheResp: func(_ *DNSContext) {
				cached = true
			},
		})
		s.resolveOnce(nil, key, slog.New(handler))

		// The response is cached even though an error is reported, and the
		// error itself is logged.
		assert.True(t, cached)
		assert.Contains(t, buf.String(), testErr.Error())
	})

	t.Run("not_ok", func(t *testing.T) {
		var cached bool
		s := newOptimisticResolver(&testCachingResolver{
			onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) {
				return false, nil
			},
			onCacheResp: func(_ *DNSContext) {
				cached = true
			},
		})
		s.resolveOnce(nil, key, slogutil.NewDiscardLogger())

		assert.False(t, cached)
	})
}
07070100000068000081A4000000000000000000000001679A649F000050F9000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/proxy.go// Package proxy implements a DNS proxy that supports all known DNS encryption
// protocols.
package proxy
import (
"cmp"
"context"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/service"
"github.com/AdguardTeam/golibs/syncutil"
"github.com/ameshkov/dnscrypt/v2"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"golang.org/x/exp/rand"
)
// defaultTimeout is the default timeout used by the proxy.
const defaultTimeout = 10 * time.Second

// minDNSPacketSize is the minimum size of a DNS packet, in bytes. 12 is the
// size of the DNS message header; 5 presumably accounts for the shortest
// possible question (a one-byte root name plus two-byte QTYPE and QCLASS).
const minDNSPacketSize = 12 + 5
// Proto is the name of a DNS protocol, as used to select the listeners in
// [Proxy.Addr] and [Proxy.Addrs].
type Proto string

// Known [Proto] values.
const (
	// ProtoUDP is plain DNS over UDP.
	ProtoUDP Proto = "udp"

	// ProtoTCP is plain DNS over TCP.
	ProtoTCP Proto = "tcp"

	// ProtoTLS is DNS over TLS (DoT).
	ProtoTLS Proto = "tls"

	// ProtoHTTPS is DNS over HTTPS (DoH).
	ProtoHTTPS Proto = "https"

	// ProtoQUIC is DNS over QUIC (DoQ).
	ProtoQUIC Proto = "quic"

	// ProtoDNSCrypt is DNSCrypt.
	ProtoDNSCrypt Proto = "dnscrypt"
)
// Proxy combines the proxy server state and configuration.
//
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
type Proxy struct {
	// requestsSema limits the number of simultaneous requests.
	//
	// TODO(a.garipov): Currently we have to pass this exact semaphore to the
	// workers, to prevent races on restart. In the future we will need a
	// better restarting mechanism that completely prevents such invalid states.
	//
	// See also: https://github.com/AdguardTeam/AdGuardHome/issues/2242.
	requestsSema syncutil.Semaphore

	// privateNets determines if the requested address and the client address
	// are private.
	privateNets netutil.SubnetSet

	// time provides the current time.
	//
	// TODO(e.burkov): Consider configuring it.
	time clock

	// randSrc provides the source of randomness.
	//
	// TODO(e.burkov): Consider configuring it.
	randSrc rand.Source

	// messages constructs DNS messages.
	messages MessageConstructor

	// beforeRequestHandler handles the request's context before it is resolved.
	beforeRequestHandler BeforeRequestHandler

	// dnsCryptServer serves DNSCrypt queries.
	dnsCryptServer *dnscrypt.Server

	// logger is used for logging in the proxy service. It is never nil.
	logger *slog.Logger

	// ratelimitBuckets is a storage for ratelimiters for individual IPs. It is
	// protected by ratelimitLock.
	ratelimitBuckets *gocache.Cache

	// fastestAddr finds the fastest IP address for the resolved domain. New
	// only sets it when the upstream mode is [UpstreamModeFastestAddr].
	fastestAddr *fastip.FastestAddr

	// cache is used to cache requests. It is disabled if nil.
	//
	// TODO(d.kolyshev): Move this cache to [Proxy.UpstreamConfig] field.
	cache *cache

	// shortFlighter is used to resolve the expired cached requests without
	// repetitions.
	shortFlighter *optimisticResolver

	// recDetector detects recursive requests that may appear when resolving
	// requests for private addresses.
	recDetector *recursionDetector

	// bytesPool is a pool of byte slices used to read DNS packets. New fills
	// it with *[]byte values large enough for a maximum-size DNS message plus
	// a 2-byte length prefix.
	//
	// TODO(e.burkov): Use [syncutil.Pool].
	bytesPool *sync.Pool

	// udpListen are the listened UDP connections.
	udpListen []*net.UDPConn

	// tcpListen are the listened TCP connections.
	tcpListen []net.Listener

	// tlsListen are the listened TCP connections with TLS.
	tlsListen []net.Listener

	// quicListen are the listened QUIC connections.
	quicListen []*quic.EarlyListener

	// quicConns are UDP connections for all listened QUIC connections. These
	// should be closed on shutdown, since *quic.EarlyListener doesn't close
	// them.
	quicConns []*net.UDPConn

	// quicTransports are transports for all listened QUIC connections. These
	// should be closed on shutdown, since *quic.EarlyListener doesn't close
	// them.
	quicTransports []*quic.Transport

	// httpsListen are the listened HTTPS connections.
	httpsListen []net.Listener

	// h3Listen are the listened HTTP/3 connections.
	h3Listen []*quic.EarlyListener

	// httpsServer serves queries received over HTTPS.
	httpsServer *http.Server

	// h3Server serves queries received over HTTP/3.
	h3Server *http3.Server

	// dnsCryptUDPListen are the listened UDP connections for DNSCrypt.
	dnsCryptUDPListen []*net.UDPConn

	// dnsCryptTCPListen are the listened TCP connections for DNSCrypt.
	dnsCryptTCPListen []net.Listener

	// upstreamRTTStats maps the upstream address to its round-trip time
	// statistics. It holds the statistics for all upstreams to perform a
	// weighted random selection when using the load balancing mode. It is
	// protected by rttLock.
	upstreamRTTStats map[string]upstreamRTTStats

	// dns64Prefs is a set of NAT64 prefixes that are used to detect and
	// construct DNS64 responses. The DNS64 function is disabled if it is
	// empty.
	dns64Prefs netutil.SliceSubnetSet

	// Config is the proxy configuration.
	//
	// TODO(a.garipov): Remove this embed and create a proper initializer.
	Config

	// udpOOBSize is the size of the out-of-band data for UDP connections.
	udpOOBSize int

	// counter counts message contexts created with [Proxy.newDNSContext].
	counter atomic.Uint64

	// RWMutex protects the whole proxy. At least started and the listener
	// fields are accessed under it (see Start, Shutdown, Addr, and Addrs).
	//
	// TODO(e.burkov): Find out what exactly it protects and name it properly.
	// Also make it a pointer.
	sync.RWMutex

	// ratelimitLock protects ratelimitBuckets.
	ratelimitLock sync.Mutex

	// rttLock protects upstreamRTTStats.
	//
	// TODO(e.burkov): Make it a pointer.
	rttLock sync.Mutex

	// started indicates if the proxy has been started. It is protected by the
	// embedded RWMutex.
	started bool
}
// New creates a new Proxy with the specified configuration. c must not be
// nil; it is copied by value into the returned proxy. Defaults are used for
// the unset private subnets, before-request handler, message constructor,
// logger, and upstream mode. An error is returned if the configuration or the
// basic auth settings are invalid, or if DNS64 can't be set up.
//
// TODO(e.burkov): Cover with tests.
func New(c *Config) (p *Proxy, err error) {
	privateNets := cmp.Or[netutil.SubnetSet](
		c.PrivateSubnets,
		netutil.SubnetSetFunc(netutil.IsLocallyServed),
	)
	reqHandler := cmp.Or[BeforeRequestHandler](
		c.BeforeRequestHandler,
		noopRequestHandler{},
	)
	msgCons := cmp.Or[MessageConstructor](
		c.MessageConstructor,
		dnsmsg.DefaultMessageConstructor{},
	)

	p = &Proxy{
		Config:               *c,
		privateNets:          privateNets,
		beforeRequestHandler: reqHandler,
		messages:             msgCons,
		time:                 realClock{},
		recDetector:          newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
		upstreamRTTStats:     map[string]upstreamRTTStats{},
		udpOOBSize:           proxynetutil.UDPGetOOBSize(),
		bytesPool: &sync.Pool{
			New: func() any {
				// 2 bytes may be used to store packet length (see TCP/TLS).
				buf := make([]byte, 2+dns.MaxMsgSize)

				return &buf
			},
		},
	}

	p.logger = c.Logger
	if p.logger == nil {
		p.logger = slog.Default().With(slogutil.KeyPrefix, LogPrefix)
	}

	// TODO(e.burkov): Validate config separately and add the contract to the
	// New function.
	if err = p.validateConfig(); err != nil {
		return nil, err
	}

	// TODO(s.chzhen): Consider moving to [Proxy.validateConfig].
	if err = p.validateBasicAuth(); err != nil {
		return nil, fmt.Errorf("basic auth: %w", err)
	}

	p.initCache()

	p.requestsSema = syncutil.EmptySemaphore{}
	if p.MaxGoroutines > 0 {
		p.logger.Info("max goroutines is set", "count", p.MaxGoroutines)
		p.requestsSema = syncutil.NewChanSemaphore(p.MaxGoroutines)
	}

	switch p.UpstreamMode {
	case "":
		p.UpstreamMode = UpstreamModeLoadBalance
	case UpstreamModeFastestAddr:
		p.fastestAddr = fastip.New(&fastip.Config{
			Logger:          p.Logger,
			PingWaitTimeout: p.FastestPingTimeout,
		})
	}

	if err = p.setupDNS64(); err != nil {
		return nil, fmt.Errorf("setting up DNS64: %w", err)
	}

	// Sort a copy so that the caller's slice isn't reordered.
	whitelist := slices.Clone(p.RatelimitWhitelist)
	slices.SortFunc(whitelist, netip.Addr.Compare)
	p.RatelimitWhitelist = whitelist

	return p, nil
}
// validateBasicAuth checks the basic-auth mode settings: if p.Config.Userinfo
// is set, at least one HTTPS listen address must be configured.
func (p *Proxy) validateBasicAuth() (err error) {
	if p.Config.Userinfo != nil && len(p.Config.HTTPSListenAddr) == 0 {
		return errors.Error("no https addrs")
	}

	return nil
}
// isStarted reports whether p has been started. It is safe for concurrent
// use.
func (p *Proxy) isStarted() (ok bool) {
	p.RLock()
	ok = p.started
	p.RUnlock()

	return ok
}
// type check
var _ service.Interface = (*Proxy)(nil)

// Start implements the [service.Interface] for *Proxy. It validates the
// listen addresses, configures the listeners, starts them, and marks p as
// started. It returns an error if p has already been started.
func (p *Proxy) Start(ctx context.Context) (err error) {
	p.logger.InfoContext(ctx, "starting dns proxy server")

	p.Lock()
	defer p.Unlock()

	if p.started {
		return errors.Error("server has been already started")
	}

	if err = p.validateListenAddrs(); err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	if err = p.configureListeners(ctx); err != nil {
		return fmt.Errorf("configuring listeners: %w", err)
	}

	p.startListeners()
	p.started = true

	return nil
}
// closeAll calls Close on each of closers in order, continuing after failures,
// and returns errs with every non-nil Close error appended in the same order.
func closeAll[C io.Closer](errs []error, closers ...C) (appended []error) {
	appended = errs
	for i := range closers {
		if closeErr := closers[i].Close(); closeErr != nil {
			appended = append(appended, closeErr)
		}
	}

	return appended
}
// Shutdown implements the [service.Interface] for *Proxy. It closes the
// listeners and servers of p, the QUIC transports and their UDP connections,
// and the main, private rDNS, and fallback upstream configurations, and then
// marks p as not started. A failure to close one resource doesn't stop the
// shutdown: all the errors are collected and returned joined. If p hasn't been
// started, it only logs a warning and returns nil.
func (p *Proxy) Shutdown(ctx context.Context) (err error) {
	p.logger.InfoContext(ctx, "stopping server")

	p.Lock()
	defer p.Unlock()

	if !p.started {
		// TODO(a.garipov): Consider returning err.
		p.logger.WarnContext(ctx, "dns proxy server is not started")

		return nil
	}

	// Each field is reset right after its resources are closed, so that p
	// keeps no references to closed resources.
	errs := closeAll(nil, p.tcpListen...)
	p.tcpListen = nil

	errs = closeAll(errs, p.udpListen...)
	p.udpListen = nil

	errs = closeAll(errs, p.tlsListen...)
	p.tlsListen = nil

	if p.httpsServer != nil {
		errs = closeAll(errs, p.httpsServer)
		p.httpsServer = nil

		// No need to close these since they're closed by httpsServer.Close().
		p.httpsListen = nil
	}

	// Unlike httpsListen above, h3Listen are closed explicitly below even
	// when h3Server is set.
	if p.h3Server != nil {
		errs = closeAll(errs, p.h3Server)
		p.h3Server = nil
	}

	errs = closeAll(errs, p.h3Listen...)
	p.h3Listen = nil

	errs = closeAll(errs, p.quicListen...)
	p.quicListen = nil

	// *quic.EarlyListener doesn't close the transports and UDP connections
	// used for QUIC, so they are closed explicitly after the listeners.
	errs = closeAll(errs, p.quicTransports...)
	p.quicTransports = nil

	errs = closeAll(errs, p.quicConns...)
	p.quicConns = nil

	errs = closeAll(errs, p.dnsCryptUDPListen...)
	p.dnsCryptUDPListen = nil

	errs = closeAll(errs, p.dnsCryptTCPListen...)
	p.dnsCryptTCPListen = nil

	// Close the upstream configurations, skipping the unset ones.
	for _, u := range []*UpstreamConfig{
		p.UpstreamConfig,
		p.PrivateRDNSUpstreamConfig,
		p.Fallbacks,
	} {
		if u != nil {
			errs = closeAll(errs, u)
		}
	}

	p.started = false

	p.logger.InfoContext(ctx, "stopped dns proxy server")

	if len(errs) > 0 {
		return fmt.Errorf("stopping dns proxy server: %w", errors.Join(errs...))
	}

	return nil
}
// addrFunc returns the network address of a listener-like value of type A.
type addrFunc[A any] func(l A) (addr net.Addr)

// collectAddrs applies af to every element of listeners, in order, and returns
// the resulting addresses. It returns nil if listeners is empty.
func collectAddrs[A any](listeners []A, af addrFunc[A]) (addrs []net.Addr) {
	for i := range listeners {
		addrs = append(addrs, af(listeners[i]))
	}

	return addrs
}
// Addrs returns all listen addresses for the specified proto or nil if the
// proxy does not listen to it. proto must be one of [Proto]: [ProtoTCP],
// [ProtoUDP], [ProtoTLS], [ProtoHTTPS], [ProtoQUIC], or [ProtoDNSCrypt];
// otherwise Addrs panics.
func (p *Proxy) Addrs(proto Proto) (addrs []net.Addr) {
	p.RLock()
	defer p.RUnlock()

	var listeners []net.Listener
	switch proto {
	case ProtoTCP:
		listeners = p.tcpListen
	case ProtoTLS:
		listeners = p.tlsListen
	case ProtoHTTPS:
		listeners = p.httpsListen
	case ProtoUDP:
		return collectAddrs(p.udpListen, (*net.UDPConn).LocalAddr)
	case ProtoQUIC:
		return collectAddrs(p.quicListen, (*quic.EarlyListener).Addr)
	case ProtoDNSCrypt:
		// Using only UDP addrs here
		//
		// TODO(ameshkov): To do it better we should either do
		// ProtoDNSCryptTCP/ProtoDNSCryptUDP or we should change the
		// configuration so that it was not possible to set different ports for
		// TCP/UDP listeners.
		return collectAddrs(p.dnsCryptUDPListen, (*net.UDPConn).LocalAddr)
	default:
		panic("proto must be 'tcp', 'tls', 'https', 'quic', 'dnscrypt' or 'udp'")
	}

	return collectAddrs(listeners, net.Listener.Addr)
}
// firstAddr returns the result of af for the first element of listeners, or
// nil if listeners is empty.
func firstAddr[A any](listeners []A, af addrFunc[A]) (addr net.Addr) {
	if len(listeners) > 0 {
		addr = af(listeners[0])
	}

	return addr
}
// Addr returns the first listen address for the specified proto or nil if the
// proxy does not listen to it. proto must be one of [Proto]: [ProtoTCP],
// [ProtoUDP], [ProtoTLS], [ProtoHTTPS], [ProtoQUIC], or [ProtoDNSCrypt];
// otherwise Addr panics. For [ProtoDNSCrypt], only the UDP listeners are
// considered.
func (p *Proxy) Addr(proto Proto) (addr net.Addr) {
	p.RLock()
	defer p.RUnlock()

	var listeners []net.Listener
	switch proto {
	case ProtoTCP:
		listeners = p.tcpListen
	case ProtoTLS:
		listeners = p.tlsListen
	case ProtoHTTPS:
		listeners = p.httpsListen
	case ProtoUDP:
		return firstAddr(p.udpListen, (*net.UDPConn).LocalAddr)
	case ProtoQUIC:
		return firstAddr(p.quicListen, (*quic.EarlyListener).Addr)
	case ProtoDNSCrypt:
		return firstAddr(p.dnsCryptUDPListen, (*net.UDPConn).LocalAddr)
	default:
		panic("proto must be 'tcp', 'tls', 'https', 'quic', 'dnscrypt' or 'udp'")
	}

	return firstAddr(listeners, net.Listener.Addr)
}
// selectUpstreams returns the upstreams to use for the request from d and
// whether those are the private ones. d.Req must contain at least one
// question.
//
// Requests for private reverse DNS, as well as requests for which
// p.shouldStripDNS64 reports true, may only use the private rDNS upstreams.
// Those are returned only if p.UsePrivateRDNS is set, the client is private,
// and p.PrivateRDNSUpstreamConfig is set; otherwise the result is empty.
//
// For other requests, the custom upstream configuration of d is consulted
// first, and p.UpstreamConfig is used if it yields nothing. DS requests are
// routed with getUpstreamsForDS instead of getUpstreamsForDomain. The returned
// slice may be empty or nil.
func (p *Proxy) selectUpstreams(d *DNSContext) (upstreams []upstream.Upstream, isPrivate bool) {
	q := d.Req.Question[0]
	host := q.Name

	if d.RequestedPrivateRDNS != (netip.Prefix{}) || p.shouldStripDNS64(d.Req) {
		private := p.PrivateRDNSUpstreamConfig
		if private == nil || !p.UsePrivateRDNS || !d.IsPrivateClient {
			return nil, true
		}

		// This may only be a PTR, SOA, and NS request.
		return private.getUpstreamsForDomain(host), true
	}

	getUpstreams := (*UpstreamConfig).getUpstreamsForDomain
	if q.Qtype == dns.TypeDS {
		getUpstreams = (*UpstreamConfig).getUpstreamsForDS
	}

	if custom := d.CustomUpstreamConfig; custom != nil {
		if customUps := getUpstreams(custom.upstream, host); len(customUps) > 0 {
			return customUps, false
		}
	}

	return getUpstreams(p.UpstreamConfig, host), false
}
// replyFromUpstream tries to resolve the request via configured upstream
// servers. It returns true if the response actually came from an upstream.
//
// If no upstreams are selected for the request, d.Res is set to an NXDOMAIN
// response and an error wrapping [upstream.ErrNoUpstreams] is returned. If
// the selected upstreams are the private ones, the request is registered with
// p.recDetector. When the exchange with the selected upstreams fails for a
// non-private request and fallbacks are configured, the fallback upstreams are
// queried in parallel instead. The query statistics are stored in d, and the
// result is applied to d by handleExchangeResult. ok is true whenever a
// response was obtained; err is the error of the last exchange performed, so
// it may be non-nil even when ok is true.
func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) {
	req := d.Req

	upstreams, isPrivate := p.selectUpstreams(d)
	if len(upstreams) == 0 {
		d.Res = p.messages.NewMsgNXDOMAIN(req)

		return false, fmt.Errorf("selecting upstream: %w", upstream.ErrNoUpstreams)
	}

	if isPrivate {
		// Remember the request so that recursive requests caused by resolving
		// it via private upstreams can be detected.
		p.recDetector.add(d.Req)
	}

	src := "upstream"
	wrapped := upstreamsWithStats(upstreams)

	// Perform the DNS request.
	resp, u, err := p.exchangeUpstreams(req, wrapped)
	// A non-nil upstream from performDNS64 means the response has been
	// resolved via DNS64, and that upstream is reported as the resolving one.
	// The bogus NXDOMAIN check is only applied otherwise.
	if dns64Ups := p.performDNS64(req, resp, wrapped); dns64Ups != nil {
		u = dns64Ups
	} else if p.isBogusNXDomain(resp) {
		p.logger.Debug("response contains bogus-nxdomain ip")

		resp = p.messages.NewMsgNXDOMAIN(req)
	}

	var wrappedFallbacks []upstream.Upstream
	// NOTE(review): fallbacks are selected with getUpstreamsForDomain even for
	// DS requests, unlike selectUpstreams; confirm this is intended.
	if err != nil && !isPrivate && p.Fallbacks != nil {
		p.logger.Debug("using fallback", slogutil.KeyError, err)

		src = "fallback"

		// upstreams mustn't appear empty since they have been validated when
		// creating proxy.
		upstreams = p.Fallbacks.getUpstreamsForDomain(req.Question[0].Name)

		wrappedFallbacks = upstreamsWithStats(upstreams)
		resp, u, err = upstream.ExchangeParallel(wrappedFallbacks, req)
	}

	if err != nil {
		p.logger.Debug("resolving err", "src", src, slogutil.KeyError, err)
	}

	if resp != nil {
		p.logger.Debug("resolved", "src", src)
	}

	unwrapped, stats := collectQueryStats(p.UpstreamMode, u, wrapped, wrappedFallbacks)
	d.queryStatistics = stats

	p.handleExchangeResult(d, req, resp, unwrapped)

	return resp != nil, err
}
// handleExchangeResult stores the outcome of an upstream exchange in d. A nil
// resp results in a SERVFAIL response to req being set as d.Res and in
// d.hasEDNS0 being reset. Otherwise resp becomes d.Res, u is recorded as the
// upstream that resolved the request, resp is passed to p.setMinMaxTTL, and
// the question section of resp is restored from req if the upstream left it
// empty.
func (p *Proxy) handleExchangeResult(d *DNSContext, req, resp *dns.Msg, u upstream.Upstream) {
	if resp == nil {
		d.Res, d.hasEDNS0 = p.messages.NewMsgSERVFAIL(req), false

		return
	}

	d.Res, d.Upstream = resp, u

	p.setMinMaxTTL(resp)

	// Explicitly construct the question section since some upstreams may
	// respond with invalidly constructed messages which cause out-of-range
	// panics afterwards.
	//
	// See https://github.com/AdguardTeam/AdGuardHome/issues/3551.
	if len(resp.Question) == 0 && len(req.Question) > 0 {
		resp.Question = []dns.Question{req.Question[0]}
	}
}
// addDO sets the DO bit in the EDNS0 OPT record of msg. If msg has no such
// record, one is added with a UDP buffer size of defaultUDPBufSize and the DO
// bit set.
func addDO(msg *dns.Msg) {
	opt := msg.IsEdns0()
	if opt == nil {
		msg.SetEdns0(defaultUDPBufSize, true)

		return
	}

	if !opt.Do() {
		opt.SetDo()
	}
}

// defaultUDPBufSize is the default size of the UDP buffer for the EDNS0 RRs,
// e.g. the ones added by addDO.
const defaultUDPBufSize = 2048
// Resolve is the default resolving method used by the DNS proxy to query
// upstream servers. It expects dctx to be filled with the request and the
// client's address. If the cache is usable for dctx and holds a response, that
// response is used and Resolve returns early, without calling
// p.ResponseHandler. Otherwise the request is resolved via the upstreams and
// the response is cached when allowed. Then the response is passed through
// filterMsg, dctx is scrubbed, and p.ResponseHandler, if set, is called with
// dctx and the resolving error, which is also returned.
func (p *Proxy) Resolve(dctx *DNSContext) (err error) {
	if p.EnableEDNSClientSubnet {
		dctx.processECS(p.EDNSAddr, p.logger)
	}

	dctx.calcFlagsAndSize()

	// Don't look up the cache for requests with DNSSEC checking disabled (CD
	// flag set), since only validated responses are cached and those may not
	// be the result desired by a user specifying the CD flag.
	cacheWorks := p.cacheWorks(dctx)
	if cacheWorks {
		if p.replyFromCache(dctx) {
			// Complete the response from cache.
			dctx.scrub()

			return nil
		}

		// On cache miss request for DNSSEC from the upstream to cache it
		// afterwards.
		addDO(dctx.Req)
	}

	var ok bool
	ok, err = p.replyFromUpstream(dctx)

	// Don't cache the responses having CD flag, just like Dnsmasq does. It
	// prevents the cache from being poisoned with unvalidated answers which may
	// differ from validated ones.
	//
	// See https://github.com/imp/dnsmasq/blob/770bce967cfc9967273d0acfb3ea018fb7b17522/src/forward.c#L1169-L1172.
	if cacheWorks && ok && !dctx.Res.CheckingDisabled {
		// Cache the response with DNSSEC RRs.
		p.cacheResp(dctx)
	}

	// It is possible that the response is nil if the upstream hasn't been
	// chosen.
	if dctx.Res != nil {
		filterMsg(dctx.Res, dctx.Res, dctx.adBit, dctx.doBit, 0)
	}

	// Complete the response.
	dctx.scrub()

	if p.ResponseHandler != nil {
		p.ResponseHandler(dctx, err)
	}

	return err
}
// cacheWorks reports whether the cache may be used for dctx. If it may not,
// the reason is logged at debug level and false is returned.
func (p *Proxy) cacheWorks(dctx *DNSContext) (ok bool) {
	notCaching := func(reason string) (ok bool) {
		p.logger.Debug("not caching", "reason", reason)

		return false
	}

	if p.cache == nil {
		return notCaching("disabled")
	}

	if dctx.RequestedPrivateRDNS != (netip.Prefix{}) {
		// Don't cache the requests intended for local upstream servers, those
		// should be fast enough as is.
		return notCaching("requested address is private")
	}

	if custom := dctx.CustomUpstreamConfig; custom != nil && custom.cache == nil {
		// In case of custom upstream cache is not configured, the global proxy
		// cache cannot be used because different upstreams can return different
		// results.
		//
		// See https://github.com/AdguardTeam/dnsproxy/issues/169.
		//
		// TODO(e.burkov): It probably should be decided after resolve.
		return notCaching("custom upstreams cache is not configured")
	}

	if dctx.Req.CheckingDisabled {
		return notCaching("dnssec check disabled")
	}

	return true
}
// processECS sets up the EDNS Client Subnet data for the request in dctx. If
// the request already carries an ECS subnet with a non-zero prefix length, it
// is passed through as dctx.ReqECS. Otherwise ECS is set from cliIP, or from
// the address of dctx.Addr if cliIP is nil, unless that address is a
// special-purpose one; the scope prefix length is 0 as RFC 7871 requires.
func (dctx *DNSContext) processECS(cliIP net.IP, l *slog.Logger) {
	reqECS, _ := ecsFromMsg(dctx.Req)
	if reqECS != nil {
		if prefixLen, _ := reqECS.Mask.Size(); prefixLen != 0 {
			dctx.ReqECS = reqECS
			l.Debug("passing through ecs", "subnet", dctx.ReqECS)

			return
		}
	}

	cliAddr, _ := netip.AddrFromSlice(cliIP)
	if cliIP == nil {
		cliAddr = dctx.Addr.Addr()
		cliIP = cliAddr.AsSlice()
	}

	if netutil.IsSpecialPurpose(cliAddr) {
		return
	}

	// A Stub Resolver MUST set SCOPE PREFIX-LENGTH to 0. See RFC 7871
	// Section 6.
	dctx.ReqECS = setECS(dctx.Req, cliIP, 0)
	l.Debug("setting ecs", "subnet", dctx.ReqECS)
}
07070100000069000081A4000000000000000000000001679A649F0000A28D000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/proxy_test.gopackage proxy
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"net/netip"
"net/url"
"sync"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/dnsproxy/upstream"
glcache "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
	// listenIP is the IP address the test listeners are bound to.
	listenIP = "127.0.0.1"

	// tlsServerName is the server name of the test TLS certificates.
	tlsServerName = "testdns.adguard.com"

	// testDefaultUpstreamAddr is the address of the upstream used by default
	// in tests.
	testDefaultUpstreamAddr = "8.8.8.8:53"

	// testMessagesCount is the number of requests sent by the tests that send
	// a series of requests.
	testMessagesCount = 10

	// defaultTestTTL is a TTL large enough to guarantee caching.
	defaultTestTTL = 1000

	// testTimeout is the default timeout for tests.
	testTimeout = 500 * time.Millisecond
)
// localhostAnyPort is a [netip.AddrPort] having a value of 127.0.0.1:0.
var localhostAnyPort = netip.MustParseAddrPort(netutil.JoinHostPort(listenIP, 0))
// defaultTrustedProxies is a set of trusted proxies that includes all possible
// IP addresses.
var defaultTrustedProxies netutil.SubnetSet = netutil.SliceSubnetSet{
netip.MustParsePrefix("0.0.0.0/0"),
netip.MustParsePrefix("::0/0"),
}
// mustNew creates a new proxy from conf using [New] and fails the test if that
// returns an error.
func mustNew(t *testing.T, conf *Config) (p *Proxy) {
	t.Helper()

	var err error
	p, err = New(conf)
	require.NoError(t, err)

	return p
}
// sendTestMessages sends [testMessagesCount] test requests over conn one at a
// time and checks each response with requireResponse before sending the next
// request.
func sendTestMessages(t *testing.T, conn *dns.Conn) {
	for n := range testMessagesCount {
		req := newTestMessage()

		writeErr := conn.WriteMsg(req)
		require.NoErrorf(t, writeErr, "req number %d", n)

		resp, readErr := conn.ReadMsg()
		require.NoErrorf(t, readErr, "resp number %d", n)

		requireResponse(t, req, resp)
	}
}
// newTestMessage returns a new A request for google-public-dns-a.google.com.
func newTestMessage() *dns.Msg {
	const host = "google-public-dns-a.google.com"

	return newHostTestMessage(host)
}
// newHostTestMessage returns a new recursive A request of class IN with a
// random ID for host. A trailing dot is appended to host, so host is expected
// not to end with one.
func newHostTestMessage(host string) (req *dns.Msg) {
	req = &dns.Msg{}
	req.SetQuestion(host+".", dns.TypeA)

	return req
}
// requireResponse fails the test unless reply is a non-nil response with the
// same ID as req and a single answer, which must be an A record for 8.8.8.8.
func requireResponse(t testing.TB, req, reply *dns.Msg) {
	t.Helper()

	require.NotNil(t, reply)
	require.Lenf(t, reply.Answer, 1, "wrong number of answers: %d", len(reply.Answer))
	require.Equal(t, req.Id, reply.Id)

	ans := reply.Answer[0]
	a, ok := ans.(*dns.A)
	require.Truef(t, ok, "wrong answer type: %v", ans)

	require.Equalf(t, net.IPv4(8, 8, 8, 8), a.A.To16(), "wrong answer: %v", a.A)
}
// newTLSConfig generates a new self-signed 2048-bit RSA CA certificate for
// tlsServerName, valid for 5*365 days starting now, and returns a TLS config
// using it together with the PEM encoding of the certificate.
func newTLSConfig(t *testing.T) (conf *tls.Config, certPem []byte) {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	// The serial number is a random value below 2^128.
	serialMax := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialMax)
	require.NoError(t, err)

	now := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{Organization: []string{"AdGuard Tests"}},
		NotBefore:    now,
		NotAfter:     now.Add(5 * 365 * 24 * time.Hour),
		KeyUsage: x509.KeyUsageKeyEncipherment |
			x509.KeyUsageDigitalSignature |
			x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
		DNSNames:              []string{tlsServerName},
	}

	// Self-signed: the template is both the certificate and its parent.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	require.NoError(t, err)

	certPem = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyPem := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})

	cert, err := tls.X509KeyPair(certPem, keyPem)
	require.NoError(t, err)

	conf = &tls.Config{
		Certificates: []tls.Certificate{cert},
		ServerName:   tlsServerName,
	}

	return conf, certPem
}
// firstIP returns the address of the first A record in the answer section of
// resp, or nil if there is no such record.
func firstIP(resp *dns.Msg) (ip net.IP) {
	for _, rr := range resp.Answer {
		if a, ok := rr.(*dns.A); ok {
			return a.A
		}
	}

	return nil
}
// testUpstream is an [upstream.Upstream] stub that replies to every request
// with the configured answers and remembers the ECS data of the requests.
type testUpstream struct {
	// ans are appended to the answer section of every response.
	ans []dns.RR

	// ecsIP, if not nil, is passed to setECS for every response, with 24 as
	// the last argument.
	ecsIP net.IP

	// ecsReqIP is the ECS IP of the latest request that contained ECS data.
	ecsReqIP net.IP

	// ecsReqMask is the ECS prefix length of the latest request that contained
	// ECS data.
	ecsReqMask int
}

// type check
var _ upstream.Upstream = (*testUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *testUpstream.
func (u *testUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	resp = new(dns.Msg).SetReply(m)
	resp.Answer = append(resp.Answer, u.ans...)

	if reqECS, _ := ecsFromMsg(m); reqECS != nil {
		u.ecsReqIP = reqECS.IP
		u.ecsReqMask, _ = reqECS.Mask.Size()
	}

	if u.ecsIP != nil {
		setECS(resp, u.ecsIP, 24)
	}

	return resp, nil
}

// Address implements the upstream.Upstream interface for *testUpstream.
func (u *testUpstream) Address() (addr string) {
	return ""
}

// Close implements the upstream.Upstream interface for *testUpstream.
func (u *testUpstream) Close() (err error) {
	return nil
}
// newTestUpstreamConfigWithBoot parses addrs into a new UpstreamConfig whose
// upstreams use the given timeout and bootstrap through a caching resolver
// backed by 8.8.8.8:53.
func newTestUpstreamConfigWithBoot(
	t require.TestingT,
	timeout time.Duration,
	addrs ...string,
) (u *UpstreamConfig) {
	bootRslv, err := upstream.NewUpstreamResolver("8.8.8.8:53", &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: timeout,
	})
	require.NoError(t, err)

	u, err = ParseUpstreamsConfig(addrs, &upstream.Options{
		Logger:    slogutil.NewDiscardLogger(),
		Timeout:   timeout,
		Bootstrap: upstream.NewCachingResolver(bootRslv),
	})
	require.NoError(t, err)

	return u
}
// newTestUpstreamConfig parses addrs into a new UpstreamConfig whose upstreams
// use the given timeout and discard their logs. It fails the test on error.
func newTestUpstreamConfig(
	t testing.TB,
	timeout time.Duration,
	addrs ...string,
) (u *UpstreamConfig) {
	t.Helper()

	opts := &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: timeout,
	}

	u, err := ParseUpstreamsConfig(addrs, opts)
	require.NoError(t, err)

	return u
}
// mustStartDefaultProxy creates and starts a proxy listening for UDP and TCP on
// [localhostAnyPort] with [testDefaultUpstreamAddr] as the upstream, and
// registers its shutdown as a test cleanup. It fails the test on error.
func mustStartDefaultProxy(t *testing.T) (p *Proxy) {
	t.Helper()

	conf := &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	}
	p = mustNew(t, conf)

	ctx := context.Background()
	require.NoError(t, p.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

	return p
}
// TestProxyRace sends multiple parallel DNS requests to the fully configured
// dnsproxy to check for race conditions.
func TestProxyRace(t *testing.T) {
	upsConf := newTestUpstreamConfig(
		t,
		defaultTimeout,
		// Use the same upstream twice so that we could rotate them.
		testDefaultUpstreamAddr,
		testDefaultUpstreamAddr,
	)

	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         upsConf,
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-UDP client connection and close it when the test is
	// over so that the socket isn't leaked.  The cleanup runs before the
	// proxy's shutdown, since cleanups run in reverse order.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	g := &sync.WaitGroup{}
	g.Add(testMessagesCount)

	// Use PanicT, since t must not be used to fail from other goroutines.
	pt := testutil.PanicT{}
	for range testMessagesCount {
		go func() {
			defer g.Done()

			req := newTestMessage()
			writeErr := conn.WriteMsg(req)
			require.NoError(pt, writeErr)

			res, readErr := conn.ReadMsg()
			require.NoError(pt, readErr)

			// We do not check if msg IDs match because the order of responses may
			// be different.

			require.NotNil(pt, res)
			require.Len(pt, res.Answer, 1)
			require.IsType(pt, &dns.A{}, res.Answer[0])

			a := res.Answer[0].(*dns.A)
			require.Equal(pt, net.IPv4(8, 8, 8, 8), a.A.To16())
		}()
	}

	g.Wait()
}
// newTxts returns txtDataLen bytes of random lowercase ASCII letters split into
// strings of at most 255 bytes each, suitable for the Txt field of a *dns.TXT.
// Any error fails the test.
func newTxts(t *testing.T, txtDataLen int) (txts []string) {
	t.Helper()

	// A DNS character-string holds at most 255 bytes, since its length is
	// stored in a single octet, so *dns.TXT requires splitting the actual data
	// into chunks of this size.
	const txtDataChunkLen = 255

	// Round up so that a partial trailing chunk is kept.
	txtDataChunkNum := (txtDataLen + txtDataChunkLen - 1) / txtDataChunkLen

	randData := make([]byte, txtDataLen)
	n, err := rand.Read(randData)
	require.NoError(t, err)
	require.Equal(t, txtDataLen, n)

	// Map the random bytes onto 'a'..'z' to keep the data printable.
	for i, c := range randData {
		randData[i] = c%26 + 'a'
	}

	txts = make([]string, txtDataChunkNum)
	for i := range txtDataChunkNum {
		start := txtDataChunkLen * i
		end := min(start+txtDataChunkLen, txtDataLen)
		txts[i] = string(randData[start:end])
	}

	return txts
}
// newDNSContext builds a request for domain with the given qtype and qclass
// using newReq and wraps it into a *DNSContext with Proto set to [ProtoUDP].
// When edns is true, the request also gets an OPT record with udpsize and the
// DO bit set.
func newDNSContext(
	domain string,
	qtype uint16,
	qclass uint16,
	edns bool,
	udpsize uint16,
) (dctx *DNSContext) {
	dctx = &DNSContext{
		Req:   newReq(domain, qtype, qclass),
		Proto: ProtoUDP,
	}

	if edns {
		dctx.Req.SetEdns0(udpsize, true)
	}

	return dctx
}
// newReq returns a new request message with a random ID and compression
// enabled, containing a single question for the fully-qualified form of domain
// with the given qtype and qclass.
func newReq(domain string, qtype, qclass uint16) (req *dns.Msg) {
	q := dns.Question{
		Name:   dns.Fqdn(domain),
		Qtype:  qtype,
		Qclass: qclass,
	}

	req = &dns.Msg{Compress: true}
	req.Id = dns.Id()
	req.Question = []dns.Question{q}

	return req
}
// TestProxy_Resolve_dnssecCache checks that the cache always stores the RRSIG
// record received from the upstream.  Responses to clients include it only
// when the request has EDNS.  A response exceeding the requested UDP buffer
// size is truncated.
func TestProxy_Resolve_dnssecCache(t *testing.T) {
	const (
		host = "example.com"

		// Larger than UDP buffer size to invoke truncation.
		txtDataLen = 1024
	)

	txt := &dns.TXT{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeTXT,
			Class:  dns.ClassINET,
		},
		Txt: newTxts(t, txtDataLen),
	}
	a := &dns.A{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
		},
		A: net.IP{1, 2, 3, 4},
	}
	ds := &dns.DS{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeDS,
			Class:  dns.ClassINET,
		},
		Digest: "736f6d652064656c65676174696f6e207369676e6572",
	}
	rrsig := &dns.RRSIG{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeRRSIG,
			Class:  dns.ClassINET,
			Ttl:    defaultTestTTL,
		},
		TypeCovered: dns.TypeA,
		Algorithm:   8,
		Labels:      2,
		SignerName:  dns.Fqdn(host),
		Signature:   "c29tZSBycnNpZyByZWxhdGVkIHN0dWZm",
	}

	u := &fakeUpstream{
		onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) {
			resp = (&dns.Msg{}).SetReply(m)

			q := m.Question[0]
			switch q.Qtype {
			case dns.TypeA:
				resp.Answer = append(resp.Answer, a)
			case dns.TypeTXT:
				resp.Answer = append(resp.Answer, txt)
			case dns.TypeDS:
				resp.Answer = append(resp.Answer, ds)
			default:
				// Go on.  The RRSIG resource record is added afterward.  This
				// upstream.Upstream implementation doesn't handle explicit
				// requests for it.
			}

			if len(resp.Answer) > 0 {
				resp.Answer[0].Header().Ttl = defaultTestTTL
			}

			// Only add the signature when the request has EDNS.
			if o := m.IsEdns0(); o != nil {
				resp.Answer = append(resp.Answer, rrsig)
				resp.SetEdns0(defaultUDPBufSize, o.Do())
			}

			return resp, nil
		},
		onAddress: func() (addr string) { return "" },
		onClose:   func() (err error) { return nil },
	}

	p := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         &UpstreamConfig{Upstreams: []upstream.Upstream{u}},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
		CacheSizeBytes:         defaultCacheSize,
	})

	testCases := []struct {
		wantAns dns.RR
		name    string
		wantLen int
		edns    bool
	}{{
		wantAns: a,
		name:    "a_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: a,
		name:    "a_edns",
		wantLen: 2,
		edns:    true,
	}, {
		wantAns: txt,
		name:    "txt_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: txt,
		name:    "txt_edns",
		// Truncated.
		wantLen: 0,
		edns:    true,
	}, {
		wantAns: ds,
		name:    "ds_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: ds,
		name:    "ds_edns",
		wantLen: 2,
		edns:    true,
	}}

	for _, tc := range testCases {
		ansHdr := tc.wantAns.Header()
		dctx := newDNSContext(ansHdr.Name, ansHdr.Rrtype, ansHdr.Class, tc.edns, txtDataLen/2)

		t.Run(tc.name, func(t *testing.T) {
			t.Cleanup(p.cache.items.Clear)

			err := p.Resolve(dctx)
			require.NoError(t, err)

			res := dctx.Res
			require.NotNil(t, res)

			require.Len(t, res.Answer, tc.wantLen, res.Answer)
			switch tc.wantLen {
			case 0:
				assert.True(t, res.Truncated)
			case 1:
				res.Answer[0].Header().Ttl = defaultTestTTL
				assert.Equal(t, tc.wantAns, res.Answer[0])
			case 2:
				res.Answer[0].Header().Ttl = defaultTestTTL
				assert.Equal(t, tc.wantAns, res.Answer[0])
				assert.Equal(t, rrsig, res.Answer[1])
			default:
				t.Fatalf("wanted length has unexpected value %d", tc.wantLen)
			}

			// The cached message always contains the signature, regardless of
			// the client's EDNS.
			cached, expired, key := p.cache.get(dctx.Req)
			require.NotNil(t, cached)
			require.Len(t, cached.m.Answer, 2)

			assert.False(t, expired)
			assert.Equal(t, key, msgToKey(dctx.Req))

			// Just make it match.
			cached.m.Answer[0].Header().Ttl = defaultTestTTL
			assert.Equal(t, tc.wantAns.String(), cached.m.Answer[0].String())
			assert.Equal(t, rrsig.String(), cached.m.Answer[1].String())
		})
	}
}
// TestExchangeWithReservedDomains checks that domain-specific upstream rules
// are applied.  Domains routed to the upstreams configured for adguard.com and
// google.ru get no answers, while the default upstream and the maps.google.ru
// exception resolve normally.
func TestExchangeWithReservedDomains(t *testing.T) {
	t.Parallel()

	dnsProxy := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: newTestUpstreamConfigWithBoot(
			t,
			testTimeout,
			"[/adguard.com/]1.2.3.4",
			"[/google.ru/]2.3.4.5",
			// NOTE(review): "#" presumably means "use the default upstreams"
			// for this subdomain — confirm against the upstream syntax docs.
			"[/maps.google.ru/]#",
			"1.1.1.1",
		),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-TCP client connection and close it when the test is
	// over.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Create google-a test message.
	req := newTestMessage()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	// Make sure that dnsproxy is working.
	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	// Create adguard.com test message.
	req = newHostTestMessage("adguard.com")
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	// Test message should not be resolved.  Check the read error, so that a
	// failed read fails the test instead of panicking on a nil res.
	res, err = conn.ReadMsg()
	require.NoError(t, err)
	require.Nil(t, res.Answer)

	// Create www.google.ru test message.
	req = newHostTestMessage("www.google.ru")
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	// Test message should not be resolved.
	res, err = conn.ReadMsg()
	require.NoError(t, err)
	require.Empty(t, res.Answer)

	// Create maps.google.ru test message.
	req = newHostTestMessage("maps.google.ru")
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	// Test message should be resolved.
	res, err = conn.ReadMsg()
	require.NoError(t, err)
	require.NotNil(t, res.Answer)
}
// TestOneByOneUpstreamsExchange tries to resolve a DNS request with one valid
// and two invalid upstreams.  It checks that the answer comes from the valid
// upstream well within three times the upstream timeout.
func TestOneByOneUpstreamsExchange(t *testing.T) {
	t.Parallel()

	dnsProxy := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: newTestUpstreamConfigWithBoot(
			t,
			testTimeout,
			"https://fake-dns.com/fake-dns-query",
			"tls://fake-dns.com",
			"1.1.1.1",
		),
		TrustedProxies:         defaultTrustedProxies,
		Fallbacks:              newTestUpstreamConfig(t, testTimeout, "1.2.3.4:567"),
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-TCP client connection and close it when the test is
	// over.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Make sure that the response is okay and resolved by the valid upstream.
	req := newTestMessage()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	start := time.Now()
	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	elapsed := time.Since(start)
	assert.Greater(t, 3*testTimeout, elapsed)
}
// newLocalUpstreamListener starts a DNS-over-TCP server on the IPv4 localhost
// address and the given port, handling requests with h.  It waits for the
// server to start and returns the address it actually listens on, so passing 0
// as port yields an ephemeral one.  The server is shut down when the test
// finishes.
func newLocalUpstreamListener(t *testing.T, port uint16, h dns.Handler) (addr netip.AddrPort) {
	t.Helper()

	startCh := make(chan struct{})
	upsSrv := &dns.Server{
		Addr:              netip.AddrPortFrom(netutil.IPv4Localhost(), port).String(),
		Net:               "tcp",
		Handler:           h,
		NotifyStartedFunc: func() { close(startCh) },
	}

	go func() {
		// Use PanicT, since t must not be used to fail from other goroutines.
		err := upsSrv.ListenAndServe()
		require.NoError(testutil.PanicT{}, err)
	}()

	<-startCh
	testutil.CleanupAndRequireSuccess(t, upsSrv.Shutdown)

	return testutil.RequireTypeAssert[*net.TCPAddr](t, upsSrv.Listener.Addr()).AddrPort()
}
// TestFallback checks the order in which the upstreams and fallback upstreams
// are tried for different domains.  Each local upstream handler reports the
// received request ID to its channel, and every test case lists the channels
// expected to be signaled in order.
func TestFallback(t *testing.T) {
	t.Parallel()

	responseCh := make(chan uint16)
	failCh := make(chan uint16)

	// The handlers run in the servers' goroutines, so they use PanicT.
	successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		testutil.RequireSend(testutil.PanicT{}, responseCh, r.Id, testTimeout)

		require.NoError(testutil.PanicT{}, w.WriteMsg((&dns.Msg{}).SetReply(r)))
	})
	failHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		testutil.RequireSend(testutil.PanicT{}, failCh, r.Id, testTimeout)

		require.NoError(testutil.PanicT{}, w.WriteMsg(&dns.Msg{}))
	})

	successAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, successHandler).String(),
	}).String()
	alsoSuccessAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, successHandler).String(),
	}).String()
	failAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, failHandler).String(),
	}).String()

	dnsProxy := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: newTestUpstreamConfig(
			t,
			testTimeout,
			failAddr,
			"[/specific.example/]"+alsoSuccessAddr,
			// almost.failing.example will fall here first.
			"[/failing.example/]"+failAddr,
		),
		TrustedProxies: defaultTrustedProxies,
		Fallbacks: newTestUpstreamConfig(
			t,
			testTimeout,
			failAddr,
			successAddr,
			"[/failing.example/]"+failAddr,
			"[/almost.failing.example/]"+alsoSuccessAddr,
		),
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	conn, err := dns.Dial("tcp", dnsProxy.Addr(ProtoTCP).String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	testCases := []struct {
		name        string
		wantSignals []chan uint16
	}{{
		name: "general.example",
		wantSignals: []chan uint16{
			failCh,
			// Both non-specific fallbacks tried.
			failCh,
			responseCh,
		},
	}, {
		name: "specific.example",
		wantSignals: []chan uint16{
			responseCh,
		},
	}, {
		name: "failing.example",
		wantSignals: []chan uint16{
			failCh,
			failCh,
		},
	}, {
		name: "almost.failing.example",
		wantSignals: []chan uint16{
			failCh,
			responseCh,
		},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := newHostTestMessage(tc.name)

			// Use a local error to avoid writing to the parent test's variable.
			writeErr := conn.WriteMsg(req)
			require.NoError(t, writeErr)

			// This runs on the subtest's goroutine, so t can fail it directly
			// instead of panicking the whole test binary.
			for _, ch := range tc.wantSignals {
				reqID, ok := testutil.RequireReceive(t, ch, testTimeout)
				require.True(t, ok)

				assert.Equal(t, req.Id, reqID)
			}

			_, readErr := conn.ReadMsg()
			require.NoError(t, readErr)
		})
	}
}
// TestFallbackFromInvalidBootstrap checks that a request is answered by the
// fallback upstreams when the main upstream can't be bootstrapped, since its
// bootstrap resolver points to a port with no DNS server.  The answer must
// arrive well within three times the upstream timeout.
func TestFallbackFromInvalidBootstrap(t *testing.T) {
	t.Parallel()

	invalidRslv, err := upstream.NewUpstreamResolver("8.8.8.8:555", &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: testTimeout,
	})
	require.NoError(t, err)

	// Prepare the proxy server.
	upsConf, err := ParseUpstreamsConfig([]string{"tls://dns.adguard.com"}, &upstream.Options{
		Logger:    slogutil.NewDiscardLogger(),
		Bootstrap: invalidRslv,
		Timeout:   testTimeout,
	})
	require.NoError(t, err)

	dnsProxy := mustNew(t, &Config{
		Logger:         slogutil.NewDiscardLogger(),
		UDPListenAddr:  []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:  []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: upsConf,
		TrustedProxies: defaultTrustedProxies,
		Fallbacks: newTestUpstreamConfig(
			t,
			testTimeout,
			"1.0.0.1",
			"8.8.8.8",
		),
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	// Start listening.
	ctx := context.Background()
	err = dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-UDP client connection and close it when the test is
	// over.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Make sure that the response is okay and resolved by the fallback.
	req := newTestMessage()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	start := time.Now()
	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	elapsed := time.Since(start)
	assert.Greater(t, 3*testTimeout, elapsed)
}
func TestRefuseAny(t *testing.T) {
dnsProxy := mustNew(t, &Config{
Logger: slogutil.NewDiscardLogger(),
UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
TrustedProxies: defaultTrustedProxies,
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 64,
RefuseAny: true,
})
// Start listening
ctx := context.Background()
err := dnsProxy.Start(ctx)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })
// Create a DNS-over-UDP client connection
addr := dnsProxy.Addr(ProtoUDP)
client := &dns.Client{
Net: string(ProtoUDP),
Timeout: testTimeout,
}
// Create a DNS request
request := (&dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
}).SetQuestion("google.com.", dns.TypeANY)
r, _, err := client.Exchange(request, addr.String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeNotImplemented, r.Rcode)
}
func TestInvalidDNSRequest(t *testing.T) {
dnsProxy := mustNew(t, &Config{
Logger: slogutil.NewDiscardLogger(),
UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
TrustedProxies: defaultTrustedProxies,
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 64,
RefuseAny: true,
})
// Start listening
ctx := context.Background()
err := dnsProxy.Start(ctx)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })
// Create a DNS-over-UDP client connection
client := &dns.Client{
Net: string(ProtoUDP),
Timeout: testTimeout,
}
// Create a DNS request
request := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
}
r, _, err := client.Exchange(request, dnsProxy.Addr(ProtoUDP).String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeServerFailure, r.Rcode)
}
// TestResponseInRequest checks that the server drops incoming messages with
// the response flag set, so the client times out without getting an answer.
func TestResponseInRequest(t *testing.T) {
	dnsProxy := mustStartDefaultProxy(t)
	proxyAddr := dnsProxy.Addr(ProtoUDP).String()

	req := newTestMessage()
	req.Response = true

	cli := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}
	resp, _, err := cli.Exchange(req, proxyAddr)

	var opErr *net.OpError
	require.ErrorAs(t, err, &opErr)

	assert.True(t, opErr.Timeout())
	assert.Nil(t, resp)
}
// TestNoQuestion checks that the server responds with SERVFAIL to requests
// without a question.
func TestNoQuestion(t *testing.T) {
	dnsProxy := mustStartDefaultProxy(t)
	proxyAddr := dnsProxy.Addr(ProtoUDP).String()

	req := newTestMessage()
	req.Question = nil

	cli := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}
	resp, _, err := cli.Exchange(req, proxyAddr)
	require.NoError(t, err)

	assert.Equal(t, dns.RcodeServerFailure, resp.Rcode)
}
// fakeUpstream is a mock upstream implementation to simplify testing.  It
// allows assigning custom Exchange, Address, and Close methods through its
// function fields, each of which must be set before the corresponding method
// is called.
//
// TODO(e.burkov): Use dnsproxytest.FakeUpstream instead.
type fakeUpstream struct {
	// onExchange is called by Exchange.
	onExchange func(m *dns.Msg) (resp *dns.Msg, err error)

	// onAddress is called by Address.
	onAddress func() (addr string)

	// onClose is called by Close.
	onClose func() (err error)
}

// type check
var _ upstream.Upstream = (*fakeUpstream)(nil)

// Exchange implements the [upstream.Upstream] interface for *fakeUpstream.
func (u *fakeUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return u.onExchange(m) }

// Address implements the [upstream.Upstream] interface for *fakeUpstream.
func (u *fakeUpstream) Address() (addr string) { return u.onAddress() }

// Close implements the [upstream.Upstream] interface for *fakeUpstream.
func (u *fakeUpstream) Close() (err error) { return u.onClose() }
// TestProxy_ReplyFromUpstream_badResponse checks that resolving through an
// upstream returning a response with an empty question section neither panics
// nor fails, and that the result carries the request's question.
func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) {
	dnsProxy := mustStartDefaultProxy(t)

	exchange := func(m *dns.Msg) (resp *dns.Msg, err error) {
		resp = (&dns.Msg{}).SetReply(m)

		ans := &dns.A{
			Hdr: dns.RR_Header{
				Name:   m.Question[0].Name,
				Class:  dns.ClassINET,
				Rrtype: dns.TypeA,
			},
			A: net.IP{1, 2, 3, 4},
		}
		resp.Answer = append(resp.Answer, ans)

		// Make the response invalid.
		resp.Question = []dns.Question{}

		return resp, nil
	}

	u := &fakeUpstream{
		onExchange: exchange,
		onAddress:  func() (addr string) { return "stub" },
		onClose:    func() error { panic("not implemented") },
	}

	upsConf := &UpstreamConfig{Upstreams: []upstream.Upstream{u}}
	d := &DNSContext{
		CustomUpstreamConfig: NewCustomUpstreamConfig(upsConf, false, 0, false),
		Req:                  newHostTestMessage("host"),
		Addr:                 netip.MustParseAddrPort("1.2.3.0:1234"),
	}

	var err error
	require.NotPanics(t, func() { err = dnsProxy.Resolve(d) })
	require.NoError(t, err)

	assert.Equal(t, d.Req.Question[0], d.Res.Question[0])
}
// TestExchangeCustomUpstreamConfig checks that a request carrying a custom
// upstream configuration is answered by the custom upstream.
func TestExchangeCustomUpstreamConfig(t *testing.T) {
	prx := mustStartDefaultProxy(t)

	ansIP := net.IP{4, 3, 2, 1}
	ans := &dns.A{
		Hdr: dns.RR_Header{
			Rrtype: dns.TypeA,
			Name:   "host.",
			Ttl:    60,
		},
		A: ansIP,
	}
	u := &testUpstream{ans: []dns.RR{ans}}

	upsConf := &UpstreamConfig{Upstreams: []upstream.Upstream{u}}
	d := DNSContext{
		CustomUpstreamConfig: NewCustomUpstreamConfig(upsConf, false, 0, false),
		Req:                  newHostTestMessage("host"),
		Addr:                 netip.MustParseAddrPort("1.2.3.0:1234"),
	}

	require.NoError(t, prx.Resolve(&d))

	assert.Equal(t, ansIP, firstIP(d.Res))
}
// TestExchangeCustomUpstreamConfigCache checks that the cache of a custom
// upstream configuration serves a repeated request without querying the
// upstream, and that ClearCache makes the next request reach the upstream
// again.
func TestExchangeCustomUpstreamConfigCache(t *testing.T) {
	prx := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
	})

	ctx := context.Background()
	require.NoError(t, prx.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

	ansIP := net.IP{4, 3, 2, 1}

	// exchanges counts the requests that actually reached the upstream.
	exchanges := 0
	u := &fakeUpstream{
		onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) {
			resp = (&dns.Msg{}).SetReply(m)
			resp.Answer = append(resp.Answer, &dns.A{
				Hdr: dns.RR_Header{
					Name:   m.Question[0].Name,
					Class:  dns.ClassINET,
					Rrtype: dns.TypeA,
					Ttl:    defaultTestTTL,
				},
				A: ansIP,
			})
			exchanges++

			return resp, nil
		},
		onAddress: func() (addr string) { return "stub" },
		onClose:   func() error { panic("not implemented") },
	}

	customUpstreamConfig := NewCustomUpstreamConfig(
		&UpstreamConfig{Upstreams: []upstream.Upstream{u}},
		true,
		defaultCacheSize,
		prx.EnableEDNSClientSubnet,
	)

	d := DNSContext{
		CustomUpstreamConfig: customUpstreamConfig,
		Req:                  newHostTestMessage("host"),
		Addr:                 netip.MustParseAddrPort("1.2.3.0:1234"),
	}

	// The first request goes to the upstream.
	require.NoError(t, prx.Resolve(&d))
	require.Equal(t, 1, exchanges)
	assert.Equal(t, ansIP, firstIP(d.Res))

	// The second one is served from the custom cache.
	require.NoError(t, prx.Resolve(&d))
	assert.Equal(t, 1, exchanges)
	assert.Equal(t, ansIP, firstIP(d.Res))

	// After clearing the cache, the upstream is queried again.
	customUpstreamConfig.ClearCache()

	require.NoError(t, prx.Resolve(&d))
	assert.Equal(t, 2, exchanges)
	assert.Equal(t, ansIP, firstIP(d.Res))
}
// TestECS checks the subnet returned by setECS for IPv4 and IPv6 addresses and
// the subnet and scope that ecsFromMsg reads back from the same message.
func TestECS(t *testing.T) {
	t.Run("ipv4", func(t *testing.T) {
		ip := net.IP{1, 2, 3, 4}
		msg := &dns.Msg{}

		set := setECS(msg, ip, 16)
		setOnes, _ := set.Mask.Size()
		assert.Equal(t, 24, setOnes)

		got, scope := ecsFromMsg(msg)
		assert.Equal(t, ip.Mask(got.Mask), got.IP)

		gotOnes, _ := got.Mask.Size()
		assert.Equal(t, 24, gotOnes)
		assert.Equal(t, 16, scope)
	})

	t.Run("ipv6", func(t *testing.T) {
		ip := net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
		msg := &dns.Msg{}

		set := setECS(msg, ip, 48)
		setOnes, _ := set.Mask.Size()
		assert.Equal(t, 56, setOnes)

		got, scope := ecsFromMsg(msg)
		assert.Equal(t, ip.Mask(got.Mask), got.IP)

		gotOnes, _ := got.Mask.Size()
		assert.Equal(t, 56, gotOnes)
		assert.Equal(t, 48, scope)
	})
}
// TestECSProxy resolves the same host for clients from different subnets with
// EDNS Client Subnet and caching enabled.
//
// The subtests run sequentially and depend on each other: each one mutates the
// shared test upstream u and relies on what the previous ones have put into
// the proxy's cache.
func TestECSProxy(t *testing.T) {
	var (
		ip1230 = net.IP{1, 2, 3, 0}
		ip2230 = net.IP{2, 2, 3, 0}
		ip4321 = net.IP{4, 3, 2, 1}
		ip4322 = net.IP{4, 3, 2, 2}
		ip4323 = net.IP{4, 3, 2, 3}
	)

	// NOTE(review): ecsIP looks like the subnet the upstream reports in its
	// response, and ecsReqIP like the subnet it last received in a request;
	// testUpstream is defined elsewhere — confirm.
	u := &testUpstream{
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4321,
		}},
		ecsIP: ip1230,
	}

	prx := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{u},
		},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		EnableEDNSClientSubnet: true,
		CacheEnabled:           true,
	})

	ctx := context.Background()
	err := prx.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

	// The first request from 1.2.3.0 reaches the upstream with that client's
	// subnet.
	t.Run("cache_subnet", func(t *testing.T) {
		d := DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("1.2.3.0:1234"),
		}

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, net.IP{4, 3, 2, 1}, firstIP(d.Res))
		assert.Equal(t, ip1230, u.ecsReqIP)
	})

	// A request from 1.2.3.1 still gets the previous answer although the
	// upstream has no answer anymore, and ecsReqIP stays nil, presumably
	// because the response comes from the subnet cache.
	t.Run("serve_subnet_cache", func(t *testing.T) {
		d := &DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("1.2.3.1:1234"),
		}
		u.ans, u.ecsIP, u.ecsReqIP = nil, nil, nil

		require.NoError(t, prx.Resolve(d))

		assert.Equal(t, ip4321, firstIP(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})

	// A client from another subnet gets a fresh answer from the upstream, which
	// receives that client's subnet.
	t.Run("another_subnet", func(t *testing.T) {
		d := DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("2.2.3.0:1234"),
		}
		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4322,
		}}
		u.ecsIP = ip2230

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, ip4322, firstIP(d.Res))
		assert.Equal(t, ip2230, u.ecsReqIP)
	})

	// A loopback client gets an answer from the upstream without any subnet
	// in the request.
	t.Run("cache_general", func(t *testing.T) {
		d := DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("127.0.0.1:1234"),
		}
		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4323,
		}}
		u.ecsIP, u.ecsReqIP = nil, nil

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, ip4323, firstIP(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})

	// Another loopback client gets the previous answer although the upstream
	// has no answer anymore, presumably from the general cache.
	t.Run("serve_general_cache", func(t *testing.T) {
		d := DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("127.0.0.2:1234"),
		}
		u.ans, u.ecsIP, u.ecsReqIP = nil, nil, nil

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, ip4323, firstIP(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})
}
// TestECSProxyCacheMinMaxTTL checks that responses stored in the subnet cache
// get their TTLs raised to CacheMinTTL and lowered to CacheMaxTTL.
func TestECSProxyCacheMinMaxTTL(t *testing.T) {
	clientIP := net.IP{1, 2, 3, 0}
	u := &testUpstream{
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host.",
				Ttl:    10,
			},
			A: net.IP{4, 3, 2, 1},
		}},
		ecsIP: clientIP,
	}

	prx := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         &UpstreamConfig{Upstreams: []upstream.Upstream{u}},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		EnableEDNSClientSubnet: true,
		CacheEnabled:           true,
		CacheMinTTL:            20,
		CacheMaxTTL:            40,
	})

	ctx := context.Background()
	err := prx.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

	// First request: the upstream TTL of 10 is below CacheMinTTL.
	d := DNSContext{
		Req:  newHostTestMessage("host"),
		Addr: netip.MustParseAddrPort("1.2.3.0:1234"),
	}
	err = prx.Resolve(&d)
	require.NoError(t, err)

	// Get from cache and check the min TTL.  Require the item and its answer
	// to exist, so that a cache miss fails the test instead of panicking.
	ci, expired, key := prx.cache.getWithSubnet(d.Req, &net.IPNet{
		IP:   clientIP,
		Mask: net.CIDRMask(24, netutil.IPv4BitLen),
	})
	require.NotNil(t, ci)
	require.NotEmpty(t, ci.m.Answer)

	assert.False(t, expired)
	assert.Equal(t, key, msgToKeyWithSubnet(d.Req, clientIP, 24))
	assert.EqualValues(t, prx.CacheMinTTL, ci.m.Answer[0].Header().Ttl)

	// Second request from another subnet: the upstream TTL of 60 is above
	// CacheMaxTTL.
	clientIP = net.IP{1, 2, 4, 0}
	d.Req = newHostTestMessage("host")
	d.Addr = netip.MustParseAddrPort("1.2.4.0:1234")
	u.ans = []dns.RR{&dns.A{
		Hdr: dns.RR_Header{
			Rrtype: dns.TypeA,
			Name:   "host.",
			Ttl:    60,
		},
		A: net.IP{4, 3, 2, 1},
	}}
	u.ecsIP = clientIP

	err = prx.Resolve(&d)
	require.NoError(t, err)

	// Get from cache and check the max TTL.
	ci, expired, key = prx.cache.getWithSubnet(d.Req, &net.IPNet{
		IP:   clientIP,
		Mask: net.CIDRMask(24, netutil.IPv4BitLen),
	})
	require.NotNil(t, ci)
	require.NotEmpty(t, ci.m.Answer)

	assert.False(t, expired)
	assert.Equal(t, key, msgToKeyWithSubnet(d.Req, clientIP, 24))
	assert.EqualValues(t, prx.CacheMaxTTL, ci.m.Answer[0].Header().Ttl)
}
// TestProxy_Resolve_withOptimisticResolver checks that, with the optimistic
// cache, an expired item is served with optimisticTTL to every request made
// while the background update is in progress.  After the update finishes, the
// cache contains the fresh response with the upstream's TTL.
func TestProxy_Resolve_withOptimisticResolver(t *testing.T) {
	const (
		host             = "some.domain.name."
		nonOptimisticTTL = 3600
	)

	// buildCtx returns a context with a new A request for host.
	buildCtx := func() (dctx *DNSContext) {
		req := &dns.Msg{
			MsgHdr: dns.MsgHdr{Id: dns.Id()},
			Question: []dns.Question{{
				Name:   host,
				Qtype:  dns.TypeA,
				Qclass: dns.ClassINET,
			}},
		}

		return &DNSContext{Req: req}
	}

	// buildResp returns a reply to req with a single A record having ttl.
	buildResp := func(req *dns.Msg, ttl uint32) (resp *dns.Msg) {
		resp = (&dns.Msg{}).SetReply(req)
		resp.Answer = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   host,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    ttl,
			},
			A: net.IP{1, 2, 3, 4},
		}}

		return resp
	}

	p := &Proxy{
		Config: Config{
			CacheEnabled:    true,
			CacheOptimistic: true,
		},
		logger: slogutil.NewDiscardLogger(),
	}
	p.initCache()

	// out and in are used for a handshake with the background cache update,
	// so that the test can make requests while that update is blocked.
	out, in := make(chan unit), make(chan unit)
	p.shortFlighter.cr = &testCachingResolver{
		onReplyFromUpstream: func(dctx *DNSContext) (ok bool, err error) {
			dctx.Res = buildResp(dctx.Req, nonOptimisticTTL)

			return true, nil
		},
		onCacheResp: func(dctx *DNSContext) {
			// Report adding to cache is in process.
			out <- unit{}
			// Wait for tests to finish.
			<-in

			p.cacheResp(dctx)

			// Report adding to cache is finished.
			out <- unit{}
		},
	}

	// Two different contexts are made to emulate two different requests
	// with the same question section.
	firstCtx, secondCtx := buildCtx(), buildCtx()

	// Add expired response into cache.
	req := firstCtx.Req
	key := msgToKey(req)
	data := (&cacheItem{
		m: buildResp(req, 0),
		u: testUpsAddr,
	}).pack()
	items := glcache.New(glcache.Config{
		EnableLRU: true,
	})
	items.Set(key, data)
	p.cache.items = items

	err := p.Resolve(firstCtx)
	require.NoError(t, err)
	require.Len(t, firstCtx.Res.Answer, 1)

	assert.EqualValues(t, optimisticTTL, firstCtx.Res.Answer[0].Header().Ttl)

	// Wait for optimisticResolver to reach the tested function.
	<-out

	// The background update is now blocked, so the second request must still
	// get the stale item.
	err = p.Resolve(secondCtx)
	require.NoError(t, err)
	require.Len(t, secondCtx.Res.Answer, 1)

	assert.EqualValues(t, optimisticTTL, secondCtx.Res.Answer[0].Header().Ttl)

	// Continue and wait for it to finish.
	in <- unit{}
	<-out

	// Should be served from cache.
	data = p.cache.items.Get(msgToKey(firstCtx.Req))
	unpacked, expired := p.cache.unpackItem(data, firstCtx.Req)
	require.False(t, expired)
	require.NotNil(t, unpacked)
	require.Len(t, unpacked.m.Answer, 1)

	assert.EqualValues(t, nonOptimisticTTL, unpacked.m.Answer[0].Header().Ttl)
}
// TestProxy_HandleDNSRequest_private checks the routing of PTR requests:
//   - requests for external addresses are answered by the general upstream for
//     all clients;
//   - requests for private addresses are answered by the private RDNS upstream
//     for clients from the private subnets;
//   - requests for private addresses get NXDOMAIN for all other clients.
func TestProxy_HandleDNSRequest_private(t *testing.T) {
	t.Parallel()

	privateSet := netutil.SubnetSetFunc(netutil.IsLocallyServed)

	localIP := netip.MustParseAddrPort("192.168.0.1:1")
	require.True(t, privateSet.Contains(localIP.Addr()))

	externalIP := netip.MustParseAddrPort("4.3.2.1:1")
	require.False(t, privateSet.Contains(externalIP.Addr()))

	privateReq := (&dns.Msg{}).SetQuestion("2.0.168.192.in-addr.arpa", dns.TypePTR)
	privateResp := (&dns.Msg{}).SetReply(privateReq)
	privateResp.Compress = true

	externalReq := (&dns.Msg{}).SetQuestion("2.2.3.4.in-addr.arpa", dns.TypePTR)
	externalResp := (&dns.Msg{}).SetReply(externalReq)
	externalResp.Compress = true

	nxdomainResp := (&dns.Msg{}).SetReply(privateReq)
	nxdomainResp.Rcode = dns.RcodeNameError

	// newUps returns an upstream named addr that always answers with a copy of
	// resp.
	newUps := func(addr string, resp *dns.Msg) (u *fakeUpstream) {
		return &fakeUpstream{
			onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return resp.Copy(), nil },
			onAddress:  func() (a string) { return addr },
			onClose:    func() (err error) { return nil },
		}
	}
	generalUps := newUps("general", externalResp)
	privateUps := newUps("private", privateResp)

	messages := dnsproxytest.NewTestMessageConstructor()
	messages.OnNewMsgNXDOMAIN = func(_ *dns.Msg) (resp *dns.Msg) { return nxdomainResp }

	p := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{generalUps},
		},
		PrivateRDNSUpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{privateUps},
		},
		PrivateSubnets:     privateSet,
		UsePrivateRDNS:     true,
		MessageConstructor: messages,
	})

	ctx := context.Background()
	require.NoError(t, p.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

	testCases := []struct {
		name    string
		want    *dns.Msg
		req     *dns.Msg
		cliAddr netip.AddrPort
	}{{
		name:    "local_requests_external",
		want:    externalResp,
		req:     externalReq,
		cliAddr: localIP,
	}, {
		name:    "external_requests_external",
		want:    externalResp,
		req:     externalReq,
		cliAddr: externalIP,
	}, {
		name:    "local_requests_private",
		want:    privateResp,
		req:     privateReq,
		cliAddr: localIP,
	}, {
		name:    "external_requests_private",
		want:    nxdomainResp,
		req:     privateReq,
		cliAddr: externalIP,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			dctx := p.newDNSContext(ProtoUDP, tc.req, tc.cliAddr)
			require.NoError(t, p.handleDNSRequest(dctx))

			assert.Equal(t, tc.want, dctx.Res)
		})
	}
}
0707010000006A000081A4000000000000000000000001679A649F00000ED3000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/proxycache.gopackage proxy
import (
"net"
"slices"
)
// cacheForContext returns the cache to use for d.  It is the cache of d's
// custom upstream configuration when that configuration exists and has one,
// and the proxy's own cache otherwise.
func (p *Proxy) cacheForContext(d *DNSContext) (c *cache) {
	c = p.cache
	if custom := d.CustomUpstreamConfig; custom != nil && custom.cache != nil {
		c = custom.cache
	}

	return c
}
// replyFromCache tries to answer d from the cache selected by cacheForContext.
// It looks in the subnet cache when ECS is enabled and d carries a client
// subnet, and in the general cache otherwise.  On a hit, it sets d.Res and d's
// query statistics and returns true.  If that cache is optimistic and the
// item has expired, the stale answer is still used.  A refresh is then
// started in a separate goroutine on a reduced copy of d.
func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) {
	dctxCache := p.cacheForContext(d)

	var (
		ci          *cacheItem
		expired     bool
		key         []byte
		cacheSource string
	)

	// TODO(d.kolyshev): Use EnableEDNSClientSubnet from dctxCache.
	switch {
	case p.Config.EnableEDNSClientSubnet && d.ReqECS != nil:
		ci, expired, key = dctxCache.getWithSubnet(d.Req, d.ReqECS)
		cacheSource = "subnet cache"
	default:
		ci, expired, key = dctxCache.get(d.Req)
		cacheSource = "general cache"
	}

	if ci == nil {
		return false
	}

	d.Res = ci.m
	d.queryStatistics = cachedQueryStatistics(ci.u)

	p.logger.Debug(
		"replying from cache",
		"source", cacheSource,
		"ecs_enabled", p.Config.EnableEDNSClientSubnet,
	)

	if !dctxCache.optimistic || !expired {
		return true
	}

	// Build a reduced clone of the current context to avoid data race.
	minCtxClone := &DNSContext{
		// It is only read inside the optimistic resolver.
		CustomUpstreamConfig: d.CustomUpstreamConfig,
		ReqECS:               cloneIPNet(d.ReqECS),
		IsPrivateClient:      d.IsPrivateClient,
	}
	if d.Req != nil {
		minCtxClone.Req = d.Req.Copy()
		addDO(minCtxClone.Req)
	}

	go p.shortFlighter.resolveOnce(minCtxClone, key, p.logger)

	return true
}
// cloneIPNet returns a deep copy of n whose IP and Mask do not share memory
// with the ones of n.  A nil n yields nil.
func cloneIPNet(n *net.IPNet) (clone *net.IPNet) {
	if n != nil {
		clone = &net.IPNet{
			IP:   slices.Clone(n.IP),
			Mask: slices.Clone(n.Mask),
		}
	}

	return clone
}
// cacheResp stores d.Res in the cache returned by cacheForContext, that is the
// cache of d's custom upstream configuration if it has one, and the
// proxy-wide cache otherwise.
//
// When ECS support is disabled, the response always goes to the general cache.
// Otherwise the destination depends on the ECS options:
//
//   - both the response and the request carry ECS: the response is cached for
//     the response's subnet, whose prefix is shortened to the SCOPE
//     PREFIX-LENGTH when that is shorter than the request's source prefix
//     length; if the subnets don't match, the response is not cached at all;
//   - only the request carries ECS: the response is cached for all subnets;
//   - the request carries no ECS: the response goes to the general cache.
func (p *Proxy) cacheResp(d *DNSContext) {
dctxCache := p.cacheForContext(d)
if !p.EnableEDNSClientSubnet {
dctxCache.set(d.Res, d.Upstream, p.logger)
return
}
switch ecs, scope := ecsFromMsg(d.Res); {
case ecs != nil && d.ReqECS != nil:
// ones is the prefix length of the response's subnet and bits is the
// length of its address; reqOnes is the prefix length of the request's
// subnet.
ones, bits := ecs.Mask.Size()
reqOnes, _ := d.ReqECS.Mask.Size()
// If FAMILY, SOURCE PREFIX-LENGTH, and SOURCE PREFIX-LENGTH bits of
// ADDRESS in the response don't match the non-zero fields in the
// corresponding query, the full response MUST be dropped.
//
// See RFC 7871 Section 7.3.
//
// TODO(a.meshkov): The whole response MUST be dropped if ECS in it
// doesn't correspond.
if !ecs.IP.Mask(ecs.Mask).Equal(d.ReqECS.IP.Mask(d.ReqECS.Mask)) || ones != reqOnes {
p.logger.Debug(
"not caching response; subnet mismatch",
"ecs", ecs,
"req_ecs", d.ReqECS,
)
return
}
// If SCOPE PREFIX-LENGTH is not longer than SOURCE PREFIX-LENGTH, store
// SCOPE PREFIX-LENGTH bits of ADDRESS, and then mark the response as
// valid for all addresses that fall within that range.
//
// See RFC 7871 Section 7.3.1.
//
// A scope equal to the source prefix length needs no adjustment, since
// ones was verified to equal reqOnes above.  NOTE(review): a scope longer
// than the source prefix length is also stored with the unchanged mask.
if scope < reqOnes {
ecs.Mask = net.CIDRMask(scope, bits)
ecs.IP = ecs.IP.Mask(ecs.Mask)
}
p.logger.Debug("caching response", "ecs", ecs)
dctxCache.setWithSubnet(d.Res, d.Upstream, ecs, p.logger)
case d.ReqECS != nil:
// Cache the response for all subnets since the server doesn't support
// EDNS Client Subnet option.
dctxCache.setWithSubnet(d.Res, d.Upstream, &net.IPNet{IP: nil, Mask: nil}, p.logger)
default:
// The request carries no ECS option, so use the general cache.
dctxCache.set(d.Res, d.Upstream, p.logger)
}
}
// ClearCache removes all items from the proxy-wide cache of p, both the
// general ones and the ones stored per subnet.  It does nothing when p has no
// cache.  Caches attached to custom upstream configurations are not touched.
func (p *Proxy) ClearCache() {
	if c := p.cache; c != nil {
		c.clearItems()
		c.clearItemsWithSubnet()

		p.logger.Debug("cache cleared")
	}
}
0707010000006B000081A4000000000000000000000001679A649F000005CA000000000000000000000000000000000000002300000000dnsproxy-0.75.0/proxy/ratelimit.gopackage proxy
import (
"fmt"
"net/netip"
"slices"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
rate "github.com/beefsack/go-rate"
gocache "github.com/patrickmn/go-cache"
)
// limiterForIP returns the rate limiter stored for the key ip.  When there is
// none yet, it creates one allowing p.Ratelimit requests per second and keeps
// it for an hour.  The bucket cache itself is created lazily.  The access is
// serialized with p.ratelimitLock.
func (p *Proxy) limiterForIP(ip string) interface{} {
	p.ratelimitLock.Lock()
	defer p.ratelimitLock.Unlock()

	if p.ratelimitBuckets == nil {
		p.ratelimitBuckets = gocache.New(time.Hour, time.Hour)
	}

	if existing, found := p.ratelimitBuckets.Get(ip); found {
		return existing
	}

	limiter := rate.New(p.Ratelimit, time.Second)
	p.ratelimitBuckets.Set(ip, limiter, time.Hour)

	return limiter
}
// isRatelimited reports whether a request from addr must be dropped because
// the subnet addr belongs to has used up its p.Ratelimit requests per second.
// The subnet is addr masked to p.RatelimitSubnetLenIPv4 or
// p.RatelimitSubnetLenIPv6 bits.  Whitelisted addresses are never limited, and
// neither is anything when p.Ratelimit is not positive.
func (p *Proxy) isRatelimited(addr netip.Addr) (ok bool) {
	if p.Ratelimit <= 0 {
		// The ratelimit is disabled.
		return false
	}

	addr = addr.Unmap()

	// Already sorted by [Proxy.Init].
	_, whitelisted := slices.BinarySearchFunc(p.RatelimitWhitelist, addr, netip.Addr.Compare)
	if whitelisted {
		return false
	}

	prefixLen := p.RatelimitSubnetLenIPv6
	if addr.Is4() {
		prefixLen = p.RatelimitSubnetLenIPv4
	}

	// TODO(s.chzhen): Improve caching. Decrease allocations.
	subnet := netip.PrefixFrom(addr, prefixLen).Masked()
	value := p.limiterForIP(subnet.Addr().String())

	limiter, isLimiter := value.(*rate.RateLimiter)
	if !isLimiter {
		p.logger.Error(
			"invalid value found in ratelimit cache",
			slogutil.KeyError,
			fmt.Errorf("bad type %T", value),
		)

		return false
	}

	allowed, _ := limiter.Try()

	return !allowed
}
0707010000006C000081A4000000000000000000000001679A649F00000959000000000000000000000000000000000000002800000000dnsproxy-0.75.0/proxy/ratelimit_test.gopackage proxy
import (
"context"
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
func TestRatelimitingProxy(t *testing.T) {
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		Ratelimit:              1,
	})

	ctx := context.Background()
	require.NoError(t, dnsProxy.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Talk plain DNS-over-UDP, since the ratelimit is only applied to UDP.
	target := dnsProxy.Addr(ProtoUDP).String()
	client := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}

	// The first request fits into the limit of one request per second.
	firstReq := newTestMessage()
	resp, _, err := client.Exchange(firstReq, target)
	if err != nil {
		t.Fatalf("error in the first request: %s", err)
	}
	requireResponse(t, firstReq, resp)

	// The second one exceeds the limit, so the proxy doesn't reply and the
	// exchange must fail.
	if _, _, err = client.Exchange(newTestMessage(), target); err == nil {
		t.Fatalf("second request was not blocked")
	}
}
func TestRatelimiting(t *testing.T) {
	// Allow a single request per second.
	p := Proxy{}
	p.Ratelimit = 1

	ip := netip.MustParseAddr("127.0.0.1")

	steps := []struct {
		failMsg     string
		wantLimited bool
	}{{
		failMsg:     "First request must have been allowed",
		wantLimited: false,
	}, {
		failMsg:     "Second request must have been ratelimited",
		wantLimited: true,
	}}

	for _, s := range steps {
		if p.isRatelimited(ip) != s.wantLimited {
			t.Fatal(s.failMsg)
		}
	}
}
func TestWhitelist(t *testing.T) {
	// Allow a single request per second, but never limit the whitelisted
	// addresses.
	p := Proxy{}
	p.Ratelimit = 1
	p.RatelimitWhitelist = []netip.Addr{
		netip.MustParseAddr("127.0.0.1"),
		netip.MustParseAddr("127.0.0.2"),
		netip.MustParseAddr("127.0.0.125"),
	}

	ip := netip.MustParseAddr("127.0.0.1")

	for _, failMsg := range []string{
		"First request must have been allowed",
		"Second request must have been allowed due to whitelist",
	} {
		if p.isRatelimited(ip) {
			t.Fatal(failMsg)
		}
	}
}
0707010000006D000081A4000000000000000000000001679A649F000009A9000000000000000000000000000000000000002B00000000dnsproxy-0.75.0/proxy/recursiondetector.gopackage proxy
import (
"encoding/binary"
"time"
glcache "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// uint* sizes in bytes to improve readability.
//
// TODO(e.burkov): Remove when there will be a more regardful way to define
// those. See https://github.com/golang/go/issues/29982.
const (
// uint16sz is the size of a uint16 value in bytes.
uint16sz = 2
// uint64sz is the size of a uint64 value in bytes.
uint64sz = 8
)
// TODO(e.burkov): Consider making configurable.
const (
// recursionTTL is the time recursive request is cached for.
recursionTTL = 1 * time.Second
// cachedRecurrentReqNum is the maximum number of cached recurrent requests.
cachedRecurrentReqNum = 1000
)
// recursionDetector detects recursion in DNS forwarding.  It remembers the
// signatures (see msgToSignature) of requests passed to add and, via check,
// reports requests with a remembered signature until the remembered entry
// expires.
type recursionDetector struct {
// recentRequests maps request signatures to the 8-byte big-endian UnixNano
// time after which check stops reporting the request.
recentRequests glcache.Cache
// ttl is how long a request passed to add is reported by check.
ttl time.Duration
}
// check reports whether a request with the same signature as msg has been
// recorded by add and its record hasn't expired yet.  Messages without
// questions are never reported.
func (rd *recursionDetector) check(msg *dns.Msg) (ok bool) {
	if len(msg.Question) == 0 {
		return false
	}

	data := rd.recentRequests.Get(msgToSignature(msg))
	if data == nil {
		return false
	}

	deadline := time.Unix(0, int64(binary.BigEndian.Uint64(data)))

	return time.Now().Before(deadline)
}
// add records the signature of msg together with the moment rd.ttl from now,
// after which check stops reporting it.  Messages without questions are
// ignored.
func (rd *recursionDetector) add(msg *dns.Msg) {
	if len(msg.Question) == 0 {
		return
	}

	deadline := time.Now().Add(rd.ttl)

	// Store the deadline as big-endian UnixNano, which is what check decodes.
	val := make([]byte, uint64sz)
	binary.BigEndian.PutUint64(val, uint64(deadline.UnixNano()))

	rd.recentRequests.Set(msgToSignature(msg), val)
}
// clear clears the recent requests cache, so that check reports nothing until
// new requests are added.
func (rd *recursionDetector) clear() {
rd.recentRequests.Clear()
}
// newRecursionDetector returns a recursionDetector that reports added requests
// for ttl and stores at most suspectsNum of them in a cache with LRU eviction
// enabled.
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
	conf := glcache.Config{
		MaxCount:  suspectsNum,
		EnableLRU: true,
	}

	rd = &recursionDetector{
		ttl:            ttl,
		recentRequests: glcache.New(conf),
	}

	return rd
}
// msgToSignature converts msg into its signature represented in bytes.  The
// signature always has 2*uint16sz+netutil.MaxDomainNameLen bytes and consists
// of the message ID, the type of the first question, and the name of the first
// question, in that order.  The name is zero-padded, or truncated if it is
// longer than netutil.MaxDomainNameLen.  msg must have at least one question.
func msgToSignature(msg *dns.Msg) (sig []byte) {
	sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)

	// The binary.BigEndian byte order is used everywhere except when the real
	// machine's endianness is needed.
	byteOrder := binary.BigEndian

	q := msg.Question[0]
	byteOrder.PutUint16(sig[0:], msg.Id)
	byteOrder.PutUint16(sig[uint16sz:], q.Qtype)

	// copy accepts a string source directly, so don't convert the name into
	// a temporary byte slice, which may allocate for longer names.
	copy(sig[2*uint16sz:], q.Name)

	return sig
}
0707010000006E000081A4000000000000000000000001679A649F00000EDC000000000000000000000000000000000000003900000000dnsproxy-0.75.0/proxy/recursiondetector_internal_test.gopackage proxy
import (
"bytes"
"encoding/binary"
"log/slog"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestRecursionDetector_Check(t *testing.T) {
	const (
		recID    = 1234
		nonRecID = recID * 2
		recTTL   = time.Hour * 1
	)

	// With a zero TTL, everything added through rd.add is already expired by
	// the time it is checked.
	rd := newRecursionDetector(0, 2)

	question := dns.Question{
		Name:  "some.domain",
		Qtype: dns.TypeAAAA,
	}
	msg := &dns.Msg{
		MsgHdr:   dns.MsgHdr{Id: recID},
		Question: []dns.Question{question},
	}

	// Store the recurrent message manually with a big TTL.
	deadline := make([]byte, uint64sz)
	binary.BigEndian.PutUint64(deadline, uint64(time.Now().Add(recTTL).UnixNano()))
	rd.recentRequests.Set(msgToSignature(msg), deadline)

	// Add an expired message.
	msg.Id = nonRecID
	rd.add(msg)

	testCases := []struct {
		name      string
		questions []dns.Question
		id        uint16
		want      bool
	}{{
		name:      "recurrent",
		questions: []dns.Question{question},
		id:        recID,
		want:      true,
	}, {
		name:      "not_suspected",
		questions: []dns.Question{question},
		id:        recID + 1,
		want:      false,
	}, {
		name:      "expired",
		questions: []dns.Question{question},
		id:        nonRecID,
		want:      false,
	}, {
		name:      "empty",
		questions: []dns.Question{},
		id:        nonRecID,
		want:      false,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := &dns.Msg{
				MsgHdr:   dns.MsgHdr{Id: tc.id},
				Question: tc.questions,
			}

			assert.Equal(t, tc.want, rd.check(req))
		})
	}
}
func TestRecursionDetector_Suspect(t *testing.T) {
	rd := newRecursionDetector(0, 1)

	simpleMsg := &dns.Msg{
		MsgHdr: dns.MsgHdr{Id: 1234},
		Question: []dns.Question{{
			Name:  "some.domain",
			Qtype: dns.TypeA,
		}},
	}

	testCases := []struct {
		msg  *dns.Msg
		name string
		want int
	}{{
		msg:  simpleMsg,
		name: "simple",
		want: 1,
	}, {
		msg:  &dns.Msg{},
		name: "unencumbered",
		want: 0,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Clear the detector after each case so that cases don't affect
			// each other.
			t.Cleanup(rd.clear)

			rd.add(tc.msg)
			got := rd.recentRequests.Stats().Count

			assert.Equal(t, tc.want, got)
		})
	}
}
// byteSink is a typed sink for benchmark results.  Assigning the results to a
// package-level variable keeps the compiler from optimizing the benchmarked
// calls away.
var byteSink []byte
// BenchmarkMsgToSignature compares msgToSignature with the reflection-based
// msgToSignatureSlow on the same single-question message.  Both functions are
// called directly, not through a function value, so that the measurements
// aren't affected by indirect calls.
func BenchmarkMsgToSignature(b *testing.B) {
const name = "some.not.very.long.host.name"
msg := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: name,
Qtype: dns.TypeAAAA,
}},
}
b.Run("efficient", func(b *testing.B) {
b.ReportAllocs()
for range b.N {
byteSink = msgToSignature(msg)
}
assert.NotEmpty(b, byteSink)
})
b.Run("inefficient", func(b *testing.B) {
b.ReportAllocs()
for range b.N {
byteSink = msgToSignatureSlow(msg)
}
assert.NotEmpty(b, byteSink)
})
// goos: darwin
// goarch: amd64
// pkg: github.com/AdguardTeam/dnsproxy/proxy
// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
// BenchmarkMsgToSignature/efficient-12 17155314 68.84 ns/op 288 B/op 1 allocs/op
// BenchmarkMsgToSignature/inefficient-12 460803 2367 ns/op 648 B/op 6 allocs/op
}
// msgToSignatureSlow converts msg into its signature represented in bytes
// using reflection-based encoding/binary, which is the less efficient way.
// The name of the first question comes first, followed by the message ID and
// the question type.  Encoding errors are only logged at the debug level.
// msg must have at least one question.
//
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg *dns.Msg) (sig []byte) {
	// The field order defines the layout of the written bytes.
	type msgSignature struct {
		name  [netutil.MaxDomainNameLen]byte
		id    uint16
		qtype uint16
	}

	q := msg.Question[0]

	var s msgSignature
	s.id = msg.Id
	s.qtype = q.Qtype
	copy(s.name[:], q.Name)

	var buf bytes.Buffer
	err := binary.Write(&buf, binary.BigEndian, s)
	if err != nil {
		slog.Default().Debug("writing message signature", slogutil.KeyError, err)
	}

	return buf.Bytes()
}
0707010000006F000081A4000000000000000000000001679A649F00001AE0000000000000000000000000000000000000002000000000dnsproxy-0.75.0/proxy/server.gopackage proxy
import (
"context"
"fmt"
"io"
"log/slog"
"net"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
// configureListeners creates the listeners for every protocol p is configured
// to serve, in this order: UDP, TCP, TLS, HTTPS, QUIC, and DNSCrypt.  It stops
// at the first failure and returns that error unchanged.
func (p *Proxy) configureListeners(ctx context.Context) (err error) {
	creators := []func() error{
		func() error { return p.createUDPListeners(ctx) },
		func() error { return p.createTCPListeners(ctx) },
		p.createTLSListeners,
		p.createHTTPSListeners,
		p.createQUICListeners,
		p.createDNSCryptListeners,
	}

	for _, create := range creators {
		if err = create(); err != nil {
			return err
		}
	}

	return nil
}
// startListeners starts a goroutine serving each listener stored in p.
//
// UDP, TCP, TLS, and QUIC listeners are served by the proxy's own packet loops,
// which all share p.requestsSema.  HTTPS, HTTP/3, and DNSCrypt listeners are
// served by p.httpsServer, p.h3Server, and p.dnsCryptServer respectively, and
// the errors those servers return are discarded.  startListeners doesn't wait
// for the goroutines.
func (p *Proxy) startListeners() {
for _, l := range p.udpListen {
go p.udpPacketLoop(l, p.requestsSema)
}
for _, l := range p.tcpListen {
go p.tcpPacketLoop(l, ProtoTCP, p.requestsSema)
}
for _, l := range p.tlsListen {
go p.tcpPacketLoop(l, ProtoTLS, p.requestsSema)
}
// The closures below receive the listener as an argument, so each goroutine
// works with its own listener regardless of loop variable semantics.
for _, l := range p.httpsListen {
go func(l net.Listener) { _ = p.httpsServer.Serve(l) }(l)
}
for _, l := range p.h3Listen {
go func(l *quic.EarlyListener) { _ = p.h3Server.ServeListener(l) }(l)
}
for _, l := range p.quicListen {
go p.quicPacketLoop(l, p.requestsSema)
}
for _, l := range p.dnsCryptUDPListen {
go func(l *net.UDPConn) { _ = p.dnsCryptServer.ServeUDP(l) }(l)
}
for _, l := range p.dnsCryptTCPListen {
go func(l net.Listener) { _ = p.dnsCryptServer.ServeTCP(l) }(l)
}
}
// handleDNSRequest processes the context.  It performs these steps in order:
//
//  1. log the request;
//  2. drop incoming response packets;
//  3. determine whether the client is private;
//  4. run the before-request handler;
//  5. apply the UDP ratelimit;
//  6. validate the request and resolve it if valid;
//  7. respond.
//
// The only error it returns is the one from the [RequestHandler], or [Resolve]
// if the [RequestHandler] is not set; respond is still called when that error
// is not nil.
//
// d is left without a response, and respond is not called, when d.Req is
// itself a response, when the before-request handler says so (see the
// documentation to [BeforeRequestHandler]), or when the client is ratelimited.
func (p *Proxy) handleDNSRequest(d *DNSContext) (err error) {
p.logDNSMessage(d.Req)
if d.Req.Response {
p.logger.Debug("dropping incoming response packet", "addr", d.Addr)
return nil
}
// IsPrivateClient is set before any handler runs; isForbiddenARPA, called
// from validateRequest below, relies on it.
ip := d.Addr.Addr()
d.IsPrivateClient = p.privateNets.Contains(ip)
if !p.handleBefore(d) {
return nil
}
// ratelimit based on IP only, protects CPU cycles and outbound connections
//
// TODO(e.burkov): Investigate if written above true and move to UDP server
// implementation?
if d.Proto == ProtoUDP && p.isRatelimited(ip) {
p.logger.Debug("ratelimited based on ip only", "addr", d.Addr)
// Don't reply to ratelimited clients.
return nil
}
// Invalid requests get a ready-made response and are not resolved.
d.Res = p.validateRequest(d)
if d.Res == nil {
if p.RequestHandler != nil {
err = errors.Annotate(p.RequestHandler(p, d), "using request handler: %w")
} else {
err = errors.Annotate(p.Resolve(d), "using default request handler: %w")
}
}
p.logDNSMessage(d.Res)
p.respond(d)
return err
}
// validateRequest checks d.Req and returns a ready-made response when the
// request must not be resolved, or nil when the request is fine.  The checks
// run in order, and the first failing one determines the response:
//
//   - not exactly one question: SERVFAIL;
//   - a query of type ANY while RefuseAny is set: NOTIMPLEMENTED;
//   - a request the recursion detector recognizes: NXDOMAIN;
//   - a forbidden request for a private ARPA domain: NXDOMAIN.
func (p *Proxy) validateRequest(d *DNSContext) (resp *dns.Msg) {
	req := d.Req

	if n := len(req.Question); n != 1 {
		p.logger.Debug("invalid number of questions", "req_questions_len", n)

		// TODO(e.burkov): Probably, FORMERR would be a better choice here.
		// Check out RFC.
		return p.messages.NewMsgSERVFAIL(req)
	}

	q := req.Question[0]

	if p.RefuseAny && q.Qtype == dns.TypeANY {
		// Refuse requests of type ANY (anti-DDOS measure).
		p.logger.Debug("refusing dns type any request")

		return p.messages.NewMsgNOTIMPLEMENTED(req)
	}

	if p.recDetector.check(req) {
		p.logger.Debug("recursion detected", "req_question", q.Name)

		return p.messages.NewMsgNXDOMAIN(req)
	}

	if d.isForbiddenARPA(p.privateNets, p.logger) {
		p.logger.Debug(
			"private arpa domain is requested",
			"addr", d.Addr,
			"arpa", q.Name,
		)

		return p.messages.NewMsgNXDOMAIN(req)
	}

	return nil
}
// isForbiddenARPA reports whether dctx holds a PTR, SOA, or NS request for a
// reversed address within privateNets while the client itself is not private.
// Whenever such a private address is requested, it also stores the parsed
// prefix in [DNSContext.RequestedPrivateRDNS] for future use.  Names that
// cannot be parsed as reversed addresses are logged with l and not forbidden.
func (dctx *DNSContext) isForbiddenARPA(privateNets netutil.SubnetSet, l *slog.Logger) (ok bool) {
	q := dctx.Req.Question[0]

	// TODO(e.burkov): Reconsider the list of types involved to private
	// address space. Perhaps, use the logic for any type. See
	// https://www.rfc-editor.org/rfc/rfc6761.html#section-6.1.
	if q.Qtype != dns.TypePTR && q.Qtype != dns.TypeSOA && q.Qtype != dns.TypeNS {
		return false
	}

	pref, err := netutil.ExtractReversedAddr(q.Name)
	if err != nil {
		l.Debug("parsing reversed subnet", slogutil.KeyError, err)

		return false
	}

	if !privateNets.Contains(pref.Addr()) {
		return false
	}

	dctx.RequestedPrivateRDNS = pref

	return !dctx.IsPrivateClient
}
// respond writes d.Res to the client using the writer that matches d.Proto.
// DNS-over-TLS uses the same writer as plain TCP.  When d.Conn is set, which
// it is not for DoH, a write deadline of defaultTimeout is set on it first.
//
// What happens to a nil d.Res depends on the protocol: for instance,
// respondDNSCrypt writes nothing, while respondHTTPS replies with an HTTP 500
// error.  Errors are logged and not returned.
func (p *Proxy) respond(d *DNSContext) {
	// d.Conn can be nil in the case of a DoH request.
	if d.Conn != nil {
		// Best-effort: if the deadline can't be set, the write below reports
		// the actual problem.
		_ = d.Conn.SetWriteDeadline(time.Now().Add(defaultTimeout))
	}

	var err error
	switch d.Proto {
	case ProtoUDP:
		err = p.respondUDP(d)
	case ProtoTCP, ProtoTLS:
		err = p.respondTCP(d)
	case ProtoHTTPS:
		err = p.respondHTTPS(d)
	case ProtoQUIC:
		err = p.respondQUIC(d)
	case ProtoDNSCrypt:
		err = p.respondDNSCrypt(d)
	default:
		err = fmt.Errorf("SHOULD NOT HAPPEN - unknown protocol: %s", d.Proto)
	}

	if err != nil {
		logWithNonCrit(err, "responding request", d.Proto, p.logger)
	}
}
// setMinMaxTTL adjusts the TTL of every record in the answer section of r
// using respectTTLOverrides with p.CacheMinTTL and p.CacheMaxTTL.  Each TTL
// that changes is logged at the debug level.
func (p *Proxy) setMinMaxTTL(r *dns.Msg) {
	for _, rr := range r.Answer {
		hdr := rr.Header()
		oldTTL := hdr.Ttl

		if newTTL := respectTTLOverrides(oldTTL, p.CacheMinTTL, p.CacheMaxTTL); newTTL != oldTTL {
			p.logger.Debug("ttl overwritten", "old", oldTTL, "new", newTTL)
			hdr.Ttl = newTTL
		}
	}
}
// logDNSMessage prints m line by line at the debug level, labeled "out" for
// responses and "in" for everything else.  A nil m is ignored.
func (p *Proxy) logDNSMessage(m *dns.Msg) {
	if m == nil {
		return
	}

	direction := "in"
	if m.Response {
		direction = "out"
	}

	slogutil.PrintLines(context.TODO(), p.logger, slog.LevelDebug, direction, m.String())
}
// logWithNonCrit logs err on a level chosen by how critical it is.
//
// Errors caused by a closed connection (io.EOF, net.ErrClosed, or errors
// recognized by isEPIPE) and network timeouts are logged at the debug level,
// with msg as details.  Any other error is logged at the error level with msg
// as the message.
func logWithNonCrit(err error, msg string, proto Proto, l *slog.Logger) {
	var netErr net.Error

	switch {
	case errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed), isEPIPE(err):
		l.Debug(
			"connection is closed",
			"proto", proto,
			"details", msg,
			slogutil.KeyError, err,
		)
	case errors.As(err, &netErr) && netErr.Timeout():
		l.Debug(
			"connection timed out",
			"proto", proto,
			"details", msg,
			slogutil.KeyError, err,
		)
	default:
		l.Error(msg, "proto", proto, slogutil.KeyError, err)
	}
}
07070100000070000081A4000000000000000000000001679A649F00000A9B000000000000000000000000000000000000002900000000dnsproxy-0.75.0/proxy/server_dnscrypt.gopackage proxy
import (
"context"
"fmt"
"net"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/syncutil"
"github.com/ameshkov/dnscrypt/v2"
"github.com/miekg/dns"
)
// createDNSCryptListeners sets up p.dnsCryptServer and opens the DNSCrypt UDP
// and TCP listeners for the configured addresses.
//
// It does nothing when no DNSCrypt addresses are configured.  It returns an
// error when the resolver certificate or the provider name is missing, or when
// a socket cannot be opened.
func (p *Proxy) createDNSCryptListeners() (err error) {
	if len(p.DNSCryptUDPListenAddr) == 0 && len(p.DNSCryptTCPListenAddr) == 0 {
		// Do nothing if DNSCrypt listen addresses are not specified.
		return nil
	}

	if p.DNSCryptProviderName == "" || p.DNSCryptResolverCert == nil {
		return errors.Error("invalid dnscrypt configuration: no certificate or provider name")
	}

	p.logger.Info("initializing dnscrypt", "provider", p.DNSCryptProviderName)

	handler := &dnsCryptHandler{
		proxy:   p,
		reqSema: p.requestsSema,
	}
	p.dnsCryptServer = &dnscrypt.Server{
		ProviderName: p.DNSCryptProviderName,
		ResolverCert: p.DNSCryptResolverCert,
		Handler:      handler,
	}

	for _, udpAddr := range p.DNSCryptUDPListenAddr {
		p.logger.Info("creating dnscrypt udp listener")

		conn, listenErr := net.ListenUDP(bootstrap.NetworkUDP, udpAddr)
		if listenErr != nil {
			return fmt.Errorf("listening to dnscrypt udp socket: %w", listenErr)
		}

		p.dnsCryptUDPListen = append(p.dnsCryptUDPListen, conn)
		p.logger.Info("listening for dnscrypt messages on udp", "addr", conn.LocalAddr())
	}

	for _, tcpAddr := range p.DNSCryptTCPListenAddr {
		p.logger.Info("creating a dnscrypt tcp listener")

		ln, listenErr := net.ListenTCP(bootstrap.NetworkTCP, tcpAddr)
		if listenErr != nil {
			return fmt.Errorf("listening to dnscrypt tcp socket: %w", listenErr)
		}

		p.dnsCryptTCPListen = append(p.dnsCryptTCPListen, ln)
		p.logger.Info("listening for dnscrypt messages on tcp", "addr", ln.Addr())
	}

	return nil
}
// dnsCryptHandler is the [dnscrypt.Handler] implementation that passes
// DNSCrypt requests to proxy.  It bounds the number of requests handled at
// once with reqSema.
type dnsCryptHandler struct {
	proxy   *Proxy
	reqSema syncutil.Semaphore
}

// type check
var _ dnscrypt.Handler = (*dnsCryptHandler)(nil)
// ServeDNS implements the [dnscrypt.Handler] interface for *dnsCryptHandler.
// It builds a DNS context for req, waits for a free slot in h.reqSema, and
// lets the proxy handle the request.  It returns the semaphore error, if any,
// or the error of handling.
func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) (err error) {
	clientAddr := netutil.NetAddrToAddrPort(rw.RemoteAddr())

	d := h.proxy.newDNSContext(ProtoDNSCrypt, req, clientAddr)
	d.DNSCryptResponseWriter = rw

	// TODO(d.kolyshev): Pass and use context from above.
	if err = h.reqSema.Acquire(context.Background()); err != nil {
		return fmt.Errorf("dnsproxy: dnscrypt: acquiring semaphore: %w", err)
	}
	defer h.reqSema.Release()

	return h.proxy.handleDNSRequest(d)
}
// respondDNSCrypt writes d.Res to the DNSCrypt client, over UDP or TCP, via
// d.DNSCryptResponseWriter.  A nil d.Res isn't written, so the request is
// dropped.
func (p *Proxy) respondDNSCrypt(d *DNSContext) error {
	if d.Res != nil {
		return d.DNSCryptResponseWriter.WriteMsg(d.Res)
	}

	// No response has been set, so let the request drop.
	return nil
}
07070100000071000081A4000000000000000000000001679A649F00000A4C000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/proxy/server_dnscrypt_test.gopackage proxy
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/ameshkov/dnscrypt/v2"
"github.com/ameshkov/dnsstamps"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMain runs the package tests through testutil.DiscardLogOutput, which
// presumably silences the output of the standard logger that dnscrypt still
// uses; NOTE(review): confirm against golibs/testutil.
//
// TODO(d.kolyshev): Remove this after migrating dnscrypt to slog.
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// getFreePort returns a TCP port on 127.0.0.1 that was free at the moment of
// the call.  The port is released before returning, so another process may
// take it in the meantime.  It panics if no listener can be opened, since it
// has no way to report the error to the calling test.
func getFreePort() uint {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		// Fail loudly instead of dereferencing a nil listener below.
		panic(fmt.Errorf("getting free port: %w", err))
	}

	port := uint(l.Addr().(*net.TCPAddr).Port)

	// Stop listening immediately so that the caller can bind to the port.  A
	// close error is irrelevant here, since only the port number is needed.
	_ = l.Close()

	// Sleep for 100ms, which may be necessary on Windows.
	time.Sleep(100 * time.Millisecond)

	return port
}
// createTestDNSCryptProxy returns a new proxy serving DNSCrypt over both UDP
// and TCP on the same free port, together with the resolver configuration it
// was built from.  The proxy isn't started.
//
// The test is stopped if the resolver configuration or the certificate cannot
// be created, since the proxy cannot be configured without them.
func createTestDNSCryptProxy(t *testing.T) (*Proxy, dnscrypt.ResolverConfig) {
	t.Helper()

	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	cert, err := rc.CreateCert()
	require.NoError(t, err)

	port := int(getFreePort())
	p := mustNew(t, &Config{
		Logger: slogutil.NewDiscardLogger(),
		DNSCryptUDPListenAddr: []*net.UDPAddr{{
			Port: port, IP: net.ParseIP(listenIP),
		}},
		DNSCryptTCPListenAddr: []*net.TCPAddr{{
			Port: port, IP: net.ParseIP(listenIP),
		}},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		EnableEDNSClientSubnet: true,
		CacheEnabled:           true,
		CacheMinTTL:            20,
		CacheMaxTTL:            40,
		DNSCryptProviderName:   rc.ProviderName,
		DNSCryptResolverCert:   cert,
	})

	return p, rc
}
func TestDNSCryptProxy(t *testing.T) {
	dnsProxy, rc := createTestDNSCryptProxy(t)

	ctx := context.Background()
	require.NoError(t, dnsProxy.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Build a DNS stamp pointing to the port the proxy actually listens on.
	udpAddr := dnsProxy.Addr(ProtoDNSCrypt).(*net.UDPAddr)
	stamp, err := rc.CreateStamp(fmt.Sprintf("%s:%d", listenIP, udpAddr.Port))
	assert.Nil(t, err)

	// Check both transports using the same stamp.
	for _, proto := range []string{"udp", "tcp"} {
		checkDNSCryptProxy(t, proto, stamp)
	}
}
// checkDNSCryptProxy sends a test request over proto to the DNSCrypt server
// described by stamp and requires a valid response.  Failing to fetch the
// server certificate or to exchange the message stops the test, since the
// following steps would otherwise operate on nil values.
func checkDNSCryptProxy(t *testing.T, proto string, stamp dnsstamps.ServerStamp) {
	t.Helper()

	c := &dnscrypt.Client{
		Timeout: defaultTimeout,
		Net:     proto,
	}

	// Fetch the server certificate.
	ri, err := c.DialStamp(stamp)
	require.NoError(t, err)

	// Send the test message.
	msg := newTestMessage()
	reply, err := c.Exchange(msg, ri)
	require.NoError(t, err)

	requireResponse(t, msg, reply)
}
07070100000072000081A4000000000000000000000001679A649F000021C5000000000000000000000000000000000000002600000000dnsproxy-0.75.0/proxy/server_https.gopackage proxy
import (
	"context"
	"crypto/subtle"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"strings"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/httphdr"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"golang.org/x/net/http2"
)
// listenHTTP binds a TCP socket to addr, wraps it into a TLS listener that
// offers HTTP/2 and HTTP/1.1, and appends it to p.httpsListen for the H1/H2
// server.  It returns the address the socket is actually bound to, which is
// useful when addr has port 0.
func (p *Proxy) listenHTTP(addr *net.TCPAddr) (laddr *net.TCPAddr, err error) {
	tcpListen, err := net.ListenTCP(bootstrap.NetworkTCP, addr)
	if err != nil {
		return nil, fmt.Errorf("tcp listener: %w", err)
	}

	bound := tcpListen.Addr()
	p.logger.Info("listening to https", "addr", bound)

	conf := p.TLSConfig.Clone()
	conf.NextProtos = []string{http2.NextProtoTLS, "http/1.1"}
	p.httpsListen = append(p.httpsListen, tls.NewListener(tcpListen, conf))

	return bound.(*net.TCPAddr), nil
}
// listenH3 opens a QUIC listener on addr that negotiates "h3" and appends it
// to p.h3Listen for the HTTP/3 server.
func (p *Proxy) listenH3(addr *net.UDPAddr) (err error) {
	conf := p.TLSConfig.Clone()
	conf.NextProtos = []string{"h3"}

	l, err := quic.ListenAddrEarly(addr.String(), conf, newServerQUICConfig())
	if err != nil {
		return fmt.Errorf("quic listener: %w", err)
	}

	p.logger.Info("listening to h3", "addr", l.Addr())
	p.h3Listen = append(p.h3Listen, l)

	return nil
}
// createHTTPSListeners creates the HTTPS server and, when p.HTTP3 is set, the
// HTTP/3 server.  For each address in p.HTTPSListenAddr it opens a TLS
// listener and, with HTTP/3 enabled, a QUIC listener on the same IP:port.
func (p *Proxy) createHTTPSListeners() (err error) {
	p.httpsServer = &http.Server{
		Handler:           p,
		ReadHeaderTimeout: defaultTimeout,
		WriteTimeout:      defaultTimeout,
	}

	if p.HTTP3 {
		p.h3Server = &http3.Server{Handler: p}
	}

	for _, addr := range p.HTTPSListenAddr {
		p.logger.Info("creating an https server")

		tcpAddr, listenErr := p.listenHTTP(addr)
		if listenErr != nil {
			return fmt.Errorf("failed to start HTTPS server on %s: %w", addr, listenErr)
		}

		if !p.HTTP3 {
			continue
		}

		// HTTP/3 server listens to the same pair IP:port as the one HTTP/2
		// server listens to.
		udpAddr := &net.UDPAddr{IP: tcpAddr.IP, Port: tcpAddr.Port}
		if err = p.listenH3(udpAddr); err != nil {
			return fmt.Errorf("failed to start HTTP/3 server on %s: %w", udpAddr, err)
		}
	}

	return nil
}
// newDoHReq returns new DNS request parsed from the given HTTP request.  In
// case of invalid request returns nil and the suitable status code for an HTTP
// error response:
//
//   - http.StatusBadRequest if the "dns" parameter of a GET request is empty or
//     isn't valid base64url, if the body of a POST request cannot be read, or
//     if the data isn't a valid DNS message;
//   - http.StatusUnsupportedMediaType if the content type of a POST request
//     is not "application/dns-message";
//   - http.StatusRequestEntityTooLarge if the body of a POST request is larger
//     than the largest possible DNS message, see RFC 8484 Section 4.2.1;
//   - http.StatusMethodNotAllowed if the method is neither GET nor POST.
//
// On success it returns http.StatusOK.  l must not be nil.
func newDoHReq(r *http.Request, l *slog.Logger) (req *dns.Msg, statusCode int) {
	var buf []byte
	var err error

	switch r.Method {
	case http.MethodGet:
		dnsParam := r.URL.Query().Get("dns")
		buf, err = base64.RawURLEncoding.DecodeString(dnsParam)
		if len(buf) == 0 || err != nil {
			l.Debug(
				"parsing dns request from http get param",
				"param_name", dnsParam,
				slogutil.KeyError, err,
			)

			return nil, http.StatusBadRequest
		}
	case http.MethodPost:
		// Close the body on every way out of this branch, not only after a
		// successful read.
		defer slogutil.CloseAndLog(context.TODO(), l, r.Body, slog.LevelDebug)

		contentType := r.Header.Get(httphdr.ContentType)
		if contentType != "application/dns-message" {
			l.Debug("unsupported media type", "content_type", contentType)

			return nil, http.StatusUnsupportedMediaType
		}

		// The body comes from an untrusted client, so never buffer more than
		// one byte over the largest possible DNS message.  The extra byte is
		// how an oversized body is detected.
		maxSize := int(dns.MaxMsgSize)
		buf, err = io.ReadAll(io.LimitReader(r.Body, int64(maxSize)+1))
		if err != nil {
			l.Debug("reading http request body", slogutil.KeyError, err)

			return nil, http.StatusBadRequest
		}

		if len(buf) > maxSize {
			l.Debug("http request body is too large", "max_size", maxSize)

			return nil, http.StatusRequestEntityTooLarge
		}
	default:
		l.Debug("bad http method", "method", r.Method)

		return nil, http.StatusMethodNotAllowed
	}

	req = &dns.Msg{}
	if err = req.Unpack(buf); err != nil {
		l.Debug("unpacking http msg", slogutil.KeyError, err)

		return nil, http.StatusBadRequest
	}

	return req, http.StatusOK
}
// ServeHTTP implements the [http.Handler] interface for *Proxy and handles DoH
// queries.
//
// Requests that fail basic authentication are rejected by checkBasicAuth.
// Requests that can't be turned into a DNS message get the status code chosen
// by newDoHReq, for example:
//
//   - http.StatusBadRequest if there is no DNS request data,
//   - http.StatusUnsupportedMediaType if request content type is not
//     "application/dns-message",
//   - http.StatusMethodNotAllowed if request method is not GET or POST.
//
// If the request came through a proxy that isn't in p.TrustedProxies, the
// address of that proxy is used as the client address.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	p.logger.Debug("incoming https request", "url", r.URL)

	clientAddr, proxyAddr, err := remoteAddr(r, p.logger)
	if err != nil {
		p.logger.Debug("getting real ip", slogutil.KeyError, err)
	}

	if !p.checkBasicAuth(w, r, clientAddr) {
		return
	}

	req, code := newDoHReq(r, p.logger)
	if req == nil {
		http.Error(w, http.StatusText(code), code)

		return
	}

	if proxyAddr.IsValid() {
		p.logger.Debug("request came from proxy server", "addr", proxyAddr)

		if !p.TrustedProxies.Contains(proxyAddr.Addr()) {
			p.logger.Debug("proxy is not trusted, using original remote addr", "addr", proxyAddr)

			// So the address of the proxy itself is used, as the remote address
			// parsed from headers cannot be trusted.
			//
			// TODO(e.burkov): Do not parse headers in this case.
			clientAddr = proxyAddr
		}
	}

	dctx := p.newDNSContext(ProtoHTTPS, req, clientAddr)
	dctx.HTTPRequest = r
	dctx.HTTPResponseWriter = w

	if err = p.handleDNSRequest(dctx); err != nil {
		p.logger.Debug("handling dns request", "proto", dctx.Proto, slogutil.KeyError, err)
	}
}
// checkBasicAuth validates the basic authorization credentials of r when
// p.Config.Userinfo is set.  On failure it logs the attempt, asks the client
// to authenticate with a 401 response, and returns false, meaning the request
// must not be handled.  Without configured userinfo every request is allowed.
func (p *Proxy) checkBasicAuth(
	w http.ResponseWriter,
	r *http.Request,
	raddr netip.AddrPort,
) (shouldHandle bool) {
	ui := p.Config.Userinfo
	if ui == nil {
		return true
	}

	user, pass, _ := r.BasicAuth()
	if !matchesUserinfo(ui, user, pass) {
		p.logger.Error("basic auth failed", "user", user, "raddr", raddr)

		w.Header().Set(httphdr.WWWAuthenticate, `Basic realm="DNS", charset="UTF-8"`)
		http.Error(w, "Authorization required", http.StatusUnauthorized)

		return false
	}

	return true
}
// matchesUserinfo returns true if user and pass match the username and the
// password of userinfo.  A userinfo without a password is matched by an empty
// pass.  userinfo must not be nil.
//
// The credentials come from untrusted clients, so they are compared in
// constant time.  This avoids leaking through response timing how many
// leading bytes are correct; only the lengths may leak.  Both comparisons are
// always performed.
func matchesUserinfo(userinfo *url.Userinfo, user, pass string) (ok bool) {
	requiredPassword, _ := userinfo.Password()

	userMatch := subtle.ConstantTimeCompare([]byte(user), []byte(userinfo.Username()))
	passMatch := subtle.ConstantTimeCompare([]byte(pass), []byte(requiredPassword))

	// Use a bitwise AND instead of && so that no comparison is skipped.
	return userMatch&passMatch == 1
}
// respondHTTPS writes d.Res to the DoH client as an "application/dns-message"
// body.  It sets the Server header when p.Config.HTTPSServerName is not empty.
//
// A nil d.Res, as well as a response that cannot be packed, results in an
// HTTP 500 error.  Packing and writing errors are returned.
func (p *Proxy) respondHTTPS(d *DNSContext) (err error) {
	w := d.HTTPResponseWriter

	if d.Res == nil {
		// Indicate the response's absence via a http.StatusInternalServerError.
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return nil
	}

	packed, err := d.Res.Pack()
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return fmt.Errorf("packing message: %w", err)
	}

	hdr := w.Header()
	if name := p.Config.HTTPSServerName; name != "" {
		hdr.Set(httphdr.Server, name)
	}
	hdr.Set(httphdr.ContentType, "application/dns-message")

	_, err = w.Write(packed)

	return err
}
// realIPFromHdrs returns the client's IP address taken from the first header
// of r that holds a valid one, or an error if none of them does.  Headers are
// checked in this order:
//
//  1. [httphdr.CFConnectingIP]
//  2. [httphdr.TrueClientIP]
//  3. [httphdr.XRealIP]
//  4. [httphdr.XForwardedFor]
//
// Surrounding whitespace in header values is ignored.
func realIPFromHdrs(r *http.Request) (realIP netip.Addr, err error) {
	singleAddrHdrs := [...]string{
		httphdr.CFConnectingIP,
		httphdr.TrueClientIP,
		httphdr.XRealIP,
	}

	for _, name := range singleAddrHdrs {
		realIP, err = netip.ParseAddr(strings.TrimSpace(r.Header.Get(name)))
		if err == nil {
			return realIP, nil
		}
	}

	// X-Forwarded-For is a comma-separated list; only its left-most entry is
	// used.
	xff := r.Header.Get(httphdr.XForwardedFor)
	if i := strings.IndexByte(xff, ','); i > 0 {
		xff = xff[:i]
	}

	return netip.ParseAddr(strings.TrimSpace(xff))
}
// remoteAddr returns the client's address together with the address of the
// last proxy server in front of it, if any.
//
// If the request headers carry a valid client IP, addr is that IP with port 0
// and prx is r.RemoteAddr.  Otherwise addr is r.RemoteAddr and prx is the zero
// value.  An error is returned only when r.RemoteAddr is not a valid ip:port.
func remoteAddr(r *http.Request, l *slog.Logger) (addr, prx netip.AddrPort, err error) {
	peer, err := netip.ParseAddrPort(r.RemoteAddr)
	if err != nil {
		return netip.AddrPort{}, netip.AddrPort{}, err
	}

	realIP, hdrErr := realIPFromHdrs(r)
	if hdrErr != nil {
		// No usable address in the headers, so the peer is the client itself.
		l.Debug("getting ip address from http request", slogutil.KeyError, hdrErr)

		return peer, netip.AddrPort{}, nil
	}

	l.Debug("using ip address from http request", "addr", realIP)

	// TODO(a.garipov): Add port if we can get it from headers like X-Real-Port,
	// X-Forwarded-Port, etc.
	return netip.AddrPortFrom(realIP, 0), peer, nil
}
07070100000073000081A4000000000000000000000001679A649F00002F4C000000000000000000000000000000000000002B00000000dnsproxy-0.75.0/proxy/server_https_test.gopackage proxy
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"strings"
"testing"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHttpsProxy(t *testing.T) {
	for _, tc := range []struct {
		name  string
		http3 bool
	}{
		{name: "https_proxy", http3: false},
		{name: "h3_proxy", http3: true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			serverTLS, caPem := newTLSConfig(t)

			prx := mustNew(t, &Config{
				Logger:                 slogutil.NewDiscardLogger(),
				TLSListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
				HTTPSListenAddr:        []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
				QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
				TLSConfig:              serverTLS,
				UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
				TrustedProxies:         defaultTrustedProxies,
				RatelimitSubnetLenIPv4: 24,
				RatelimitSubnetLenIPv6: 64,
				HTTP3:                  tc.http3,
			})

			ctx := context.Background()
			require.NoError(t, prx.Start(ctx))
			testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

			// Send one query with a client speaking the protocol under test
			// and make sure the answer corresponds to it.
			client := createTestHTTPClient(prx, caPem, tc.http3)
			query := newTestMessage()
			answer := sendTestDoHMessage(t, client, query, nil)
			requireResponse(t, query, answer)
		})
	}
}
func TestProxy_trustedProxies(t *testing.T) {
	clientAddr := netip.MustParseAddr("1.2.3.4")
	proxyAddr := netip.MustParseAddr("127.0.0.1")

	// doRequest starts a DoH proxy that trusts only trustedAddr.  It sends the
	// proxy a query whose X-Forwarded-For header lists clientAddr and
	// proxyAddr, and requires the request handler to see wantClientIP as the
	// client's address.
	doRequest := func(t *testing.T, trustedAddr, wantClientIP netip.Addr) {
		serverTLS, caPem := newTLSConfig(t)
		dnsProxy := mustNew(t, &Config{
			Logger:                 slogutil.NewDiscardLogger(),
			TLSListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
			HTTPSListenAddr:        []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
			QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
			TLSConfig:              serverTLS,
			UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
			TrustedProxies:         defaultTrustedProxies,
			RatelimitSubnetLenIPv4: 24,
			RatelimitSubnetLenIPv6: 64,
		})

		var gotAddr netip.Addr
		dnsProxy.RequestHandler = func(_ *Proxy, d *DNSContext) (err error) {
			gotAddr = d.Addr.Addr()

			return dnsProxy.Resolve(d)
		}

		client := createTestHTTPClient(dnsProxy, caPem, false)
		query := newTestMessage()

		dnsProxy.TrustedProxies = netip.PrefixFrom(trustedAddr, trustedAddr.BitLen())

		ctx := context.Background()
		require.NoError(t, dnsProxy.Start(ctx))
		testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

		xff := strings.Join([]string{clientAddr.String(), proxyAddr.String()}, ",")
		answer := sendTestDoHMessage(t, client, query, map[string]string{
			"X-Forwarded-For": xff,
		})
		requireResponse(t, query, answer)
		require.Equal(t, wantClientIP, gotAddr)
	}

	t.Run("success", func(t *testing.T) {
		doRequest(t, proxyAddr, clientAddr)
	})

	t.Run("not_in_trusted", func(t *testing.T) {
		doRequest(t, netip.MustParseAddr("127.0.0.2"), proxyAddr)
	})
}
func TestAddrsFromRequest(t *testing.T) {
var (
theIP = netip.AddrFrom4([4]byte{1, 2, 3, 4})
anotherIP = netip.AddrFrom4([4]byte{1, 2, 3, 5})
theIPStr = theIP.String()
anotherIPStr = anotherIP.String()
)
testCases := []struct {
name string
hdrs map[string]string
wantIP netip.Addr
wantErr string
}{{
name: "cf-connecting-ip",
hdrs: map[string]string{
"CF-Connecting-IP": theIPStr,
},
wantIP: theIP,
wantErr: "",
}, {
name: "true-client-ip",
hdrs: map[string]string{
"True-Client-IP": theIPStr,
},
wantIP: theIP,
wantErr: "",
}, {
name: "x-real-ip",
hdrs: map[string]string{
"X-Real-IP": theIPStr,
},
wantIP: theIP,
wantErr: "",
}, {
name: "no_any",
hdrs: map[string]string{
"CF-Connecting-IP": "invalid",
"True-Client-IP": "invalid",
"X-Real-IP": "invalid",
},
wantIP: netip.Addr{},
wantErr: `ParseAddr(""): unable to parse IP`,
}, {
name: "priority",
hdrs: map[string]string{
"X-Forwarded-For": strings.Join([]string{anotherIPStr, theIPStr}, ","),
"True-Client-IP": anotherIPStr,
"X-Real-IP": anotherIPStr,
"CF-Connecting-IP": theIPStr,
},
wantIP: theIP,
wantErr: "",
}, {
name: "x-forwarded-for_simple",
hdrs: map[string]string{
"X-Forwarded-For": strings.Join([]string{anotherIPStr, theIPStr}, ","),
},
wantIP: anotherIP,
wantErr: "",
}, {
name: "x-forwarded-for_single",
hdrs: map[string]string{
"X-Forwarded-For": theIPStr,
},
wantIP: theIP,
wantErr: "",
}, {
name: "x-forwarded-for_invalid_proxy",
hdrs: map[string]string{
"X-Forwarded-For": strings.Join([]string{theIPStr, "invalid"}, ","),
},
wantIP: theIP,
wantErr: "",
}, {
name: "x-forwarded-for_empty",
hdrs: map[string]string{
"X-Forwarded-For": "",
},
wantIP: netip.Addr{},
wantErr: `ParseAddr(""): unable to parse IP`,
}, {
name: "x-forwarded-for_redundant_spaces",
hdrs: map[string]string{
"X-Forwarded-For": " " + theIPStr + " ,\t" + anotherIPStr,
},
wantIP: theIP,
wantErr: "",
}, {
name: "cf-connecting-ip_redundant_spaces",
hdrs: map[string]string{
"CF-Connecting-IP": " " + theIPStr + "\t",
},
wantIP: theIP,
wantErr: "",
}}
for _, tc := range testCases {
r, err := http.NewRequest(http.MethodGet, "localhost", nil)
require.NoError(t, err)
for h, v := range tc.hdrs {
r.Header.Set(h, v)
}
t.Run(tc.name, func(t *testing.T) {
var ip netip.Addr
ip, err = realIPFromHdrs(r)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.wantIP, ip)
})
}
}
func TestRemoteAddr(t *testing.T) {
	const thePort = 4321

	theIP := netip.AddrFrom4([4]byte{1, 2, 3, 4})
	anotherIP := netip.AddrFrom4([4]byte{1, 2, 3, 5})
	thirdIP := netip.AddrFrom4([4]byte{1, 2, 3, 6})
	theIPStr, anotherIPStr, thirdIPStr := theIP.String(), anotherIP.String(), thirdIP.String()

	rAddr := netip.AddrPortFrom(theIP, thePort)

	// An empty wantErr means that no error is expected; when it is set, the
	// wanted addresses are not checked.
	testCases := []struct {
		name       string
		remoteAddr string
		hdrs       map[string]string
		wantErr    string
		wantIP     netip.AddrPort
		wantProxy  netip.AddrPort
	}{{
		name:       "no_proxy",
		remoteAddr: rAddr.String(),
		wantIP:     netip.AddrPortFrom(theIP, thePort),
		wantProxy:  netip.AddrPort{},
	}, {
		name:       "proxied_with_cloudflare",
		remoteAddr: rAddr.String(),
		hdrs:       map[string]string{"CF-Connecting-IP": anotherIPStr},
		wantIP:     netip.AddrPortFrom(anotherIP, 0),
		wantProxy:  netip.AddrPortFrom(theIP, thePort),
	}, {
		name:       "proxied_once",
		remoteAddr: rAddr.String(),
		hdrs:       map[string]string{"X-Forwarded-For": anotherIPStr},
		wantIP:     netip.AddrPortFrom(anotherIP, 0),
		wantProxy:  netip.AddrPortFrom(theIP, thePort),
	}, {
		name:       "proxied_multiple",
		remoteAddr: rAddr.String(),
		hdrs: map[string]string{
			"X-Forwarded-For": strings.Join([]string{anotherIPStr, thirdIPStr}, ","),
		},
		wantIP:    netip.AddrPortFrom(anotherIP, 0),
		wantProxy: netip.AddrPortFrom(theIP, thePort),
	}, {
		name:       "no_port",
		remoteAddr: theIPStr,
		wantErr:    "not an ip:port",
	}, {
		name:       "bad_port",
		remoteAddr: theIPStr + ":notport",
		wantErr:    `invalid port "notport" parsing "1.2.3.4:notport"`,
	}, {
		name:       "bad_host",
		remoteAddr: "host:1",
		wantErr:    `ParseAddr("host"): unable to parse IP`,
	}, {
		name:       "bad_proxied_host",
		remoteAddr: "host:1",
		hdrs:       map[string]string{"CF-Connecting-IP": theIPStr},
		wantErr:    `ParseAddr("host"): unable to parse IP`,
	}}

	l := slogutil.NewDiscardLogger()

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			r, err := http.NewRequest(http.MethodGet, "localhost", nil)
			require.NoError(t, err)

			r.RemoteAddr = tc.remoteAddr
			for name, val := range tc.hdrs {
				r.Header.Set(name, val)
			}

			addr, prx, err := remoteAddr(r, l)
			if tc.wantErr != "" {
				testutil.AssertErrorMsg(t, tc.wantErr, err)

				return
			}

			require.NoError(t, err)
			assert.Equal(t, tc.wantIP, addr)
			assert.Equal(t, tc.wantProxy, prx)
		})
	}
}
// sendTestDoHMessage sends m as a DoH GET request through client, adding the
// extra request headers from hdrs, and returns the unpacked DNS response.  It
// fails t if the request fails, the response is older than HTTP/2, or the
// body is not a valid DNS message.
func sendTestDoHMessage(
	t *testing.T,
	client *http.Client,
	m *dns.Msg,
	hdrs map[string]string,
) (resp *dns.Msg) {
	wire, err := m.Pack()
	require.NoError(t, err)

	// Over HTTP/3, use http3.MethodGet0RTT so that the request is sent using
	// 0-RTT.
	method := http.MethodGet
	if _, isH3 := client.Transport.(*http3.Transport); isH3 {
		method = http3.MethodGet0RTT
	}

	target := (&url.URL{
		Scheme:   "https",
		Host:     tlsServerName,
		Path:     "/dns-query",
		RawQuery: fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(wire)),
	}).String()

	req, err := http.NewRequest(method, target, nil)
	require.NoError(t, err)

	req.Header.Set("Content-Type", "application/dns-message")
	req.Header.Set("Accept", "application/dns-message")
	for name, val := range hdrs {
		req.Header.Set(name, val)
	}

	httpResp, err := client.Do(req) // nolint:bodyclose
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, httpResp.Body.Close)

	require.True(t, httpResp.ProtoAtLeast(2, 0), "the proto is too old: %s", httpResp.Proto)

	body, err := io.ReadAll(httpResp.Body)
	require.NoError(t, err)

	resp = &dns.Msg{}
	require.NoError(t, resp.Unpack(body))

	return resp
}
// createTestHTTPClient returns an *http.Client that trusts the CA in caPem and
// sends every request to the DoH address of dnsProxy, whatever the URL host.
// With http3Enabled the client speaks HTTP/3 over QUIC; otherwise it uses
// HTTP/2 (falling back to HTTP/1.1) over TLS.
func createTestHTTPClient(dnsProxy *Proxy, caPem []byte, http3Enabled bool) (client *http.Client) {
	// Trust the test CA so that the server certificate validates.
	roots := x509.NewCertPool()
	roots.AppendCertsFromPEM(caPem)

	tlsClientConfig := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    roots,
	}

	// proxyAddr is resolved on each dial, i.e. after the proxy has started.
	proxyAddr := func() string { return dnsProxy.Addr(ProtoHTTPS).String() }

	var transport http.RoundTripper
	if !http3Enabled {
		dialer := &net.Dialer{Timeout: defaultTimeout}

		tlsClientConfig.NextProtos = []string{"h2", "http/1.1"}
		transport = &http.Transport{
			TLSClientConfig:    tlsClientConfig,
			DisableCompression: true,
			DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, network, proxyAddr())
			},
			ForceAttemptHTTP2: true,
		}
	} else {
		tlsClientConfig.NextProtos = []string{"h3"}
		transport = &http3.Transport{
			Dial: func(
				ctx context.Context,
				_ string,
				tlsCfg *tls.Config,
				cfg *quic.Config,
			) (quic.EarlyConnection, error) {
				return quic.DialAddrEarly(ctx, proxyAddr(), tlsCfg, cfg)
			},
			TLSClientConfig:    tlsClientConfig,
			QUICConfig:         &quic.Config{},
			DisableCompression: true,
		}
	}

	return &http.Client{
		Transport: transport,
		Timeout:   defaultTimeout,
	}
}
07070100000074000081A4000000000000000000000001679A649F00003D1D000000000000000000000000000000000000002500000000dnsproxy-0.75.0/proxy/server_quic.gopackage proxy
import (
"context"
"encoding/binary"
"fmt"
"io"
"log/slog"
"math"
"net"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/syncutil"
"github.com/bluele/gcache"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
// NextProtoDQ is the ALPN token for DNS over dedicated QUIC connections (DoQ)
// as defined by RFC 9250: DoQ support is indicated by selecting this token in
// the QUIC crypto handshake.
//
// See https://www.rfc-editor.org/rfc/rfc9250.html.
const NextProtoDQ = "doq"

// compatProtoDQ lists every ALPN token the DoQ server accepts: NextProtoDQ,
// i.e. the RFC version, first, followed by the tokens of earlier drafts.
var compatProtoDQ = []string{NextProtoDQ, "doq-i02", "doq-i00", "dq"}

const (
	// maxQUICIdleTimeout is the maximum QUIC idle timeout.  The default value
	// in quic-go is 30 seconds, but our internal tests show that a higher
	// value works better for clients written with ngtcp2.
	maxQUICIdleTimeout = 5 * time.Minute

	// quicAddrValidatorCacheSize is the size of the cache used by the QUIC
	// address validator.  The value is chosen arbitrarily and we should
	// consider making it configurable.
	//
	// TODO(ameshkov): make it configurable.
	quicAddrValidatorCacheSize = 1000

	// quicAddrValidatorCacheTTL is the time-to-live of the items in the QUIC
	// address validator cache.  The value is chosen arbitrarily and we should
	// consider making it configurable.
	//
	// TODO(ameshkov): make it configurable.
	quicAddrValidatorCacheTTL = 30 * time.Minute
)
// DoQ application error codes, matching the ones registered by RFC 9250,
// Section 4.3.
const (
	// DoQCodeNoError (DOQ_NO_ERROR) is used when the connection or stream
	// needs to be closed, but there is no error to signal.
	DoQCodeNoError = quic.ApplicationErrorCode(0)

	// DoQCodeInternalError (DOQ_INTERNAL_ERROR) signals that the DoQ
	// implementation encountered an internal error and is incapable of
	// pursuing the transaction or the connection.
	DoQCodeInternalError = quic.ApplicationErrorCode(1)

	// DoQCodeProtocolError (DOQ_PROTOCOL_ERROR) signals that the DoQ
	// implementation encountered a protocol error and is forcibly aborting
	// the connection.
	DoQCodeProtocolError = quic.ApplicationErrorCode(2)
)
// createQUICListeners opens a UDP socket for every address in
// p.QUICListenAddr and starts a DoQ listener on top of it.  The sockets,
// transports, and listeners are appended to p.quicConns, p.quicTransports, and
// p.quicListen respectively.  It stops at the first error.
func (p *Proxy) createQUICListeners() error {
	for _, laddr := range p.QUICListenAddr {
		p.logger.Info("creating quic listener", "addr", laddr)

		udpConn, err := net.ListenUDP(bootstrap.NetworkUDP, laddr)
		if err != nil {
			return fmt.Errorf("listening to %s: %w", laddr, err)
		}

		// The socket is recorded before the QUIC layer is set up, so it stays
		// in p.quicConns even if ListenEarly fails below.
		p.quicConns = append(p.quicConns, udpConn)

		validator := newQUICAddrValidator(quicAddrValidatorCacheSize, quicAddrValidatorCacheTTL)
		tr := &quic.Transport{
			Conn:                udpConn,
			VerifySourceAddress: validator.requiresValidation,
		}

		// Work on a copy so that setting the DoQ ALPN tokens does not modify
		// p.TLSConfig itself.
		serverTLS := p.TLSConfig.Clone()
		serverTLS.NextProtos = compatProtoDQ

		ln, err := tr.ListenEarly(serverTLS, newServerQUICConfig())
		if err != nil {
			return fmt.Errorf("quic listener: %w", err)
		}

		p.quicTransports = append(p.quicTransports, tr)
		p.quicListen = append(p.quicListen, ln)

		p.logger.Info("listening quic", "addr", ln.Addr())
	}

	return nil
}
// quicPacketLoop accepts incoming QUIC connections on l and serves each of them
// in its own goroutine via handleQUICConnection.  Every connection holds one
// slot of reqSema while it is being served.  The loop exits when l stops
// accepting connections or when reqSema cannot be acquired.
//
// See also the comment on Proxy.requestsSema.
func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, reqSema syncutil.Semaphore) {
	p.logger.Info("entering dns-over-quic listener loop", "addr", l.Addr())

	for {
		ctx := context.Background()

		conn, err := l.Accept(ctx)
		if err != nil {
			logQUICError(ctx, "accepting quic conn", err, p.logger)

			break
		}

		err = reqSema.Acquire(ctx)
		if err != nil {
			p.logger.ErrorContext(
				ctx,
				"acquiring semaphore",
				"proto", ProtoQUIC,
				slogutil.KeyError, err,
			)

			// The connection has already been accepted but will never be
			// served.  Close it to make sure resources are freed, the same way
			// handleQUICConnection does on this error.
			closeQUICConn(conn, DoQCodeNoError, p.logger)

			break
		}

		go func() {
			defer reqSema.Release()

			p.handleQUICConnection(conn, reqSema)
		}()
	}
}
// logQUICError logs err with prefix.  Errors that isQUICErrorForDebugLog
// considers non-critical (closed or timed out connections and the like) are
// logged at debug level; every other error is logged at error level.
func logQUICError(ctx context.Context, prefix string, err error, l *slog.Logger) {
	if !isQUICErrorForDebugLog(err) {
		l.ErrorContext(ctx, prefix, slogutil.KeyError, err)

		return
	}

	l.DebugContext(
		ctx,
		"closed or timed out",
		slogutil.KeyPrefix, prefix,
		slogutil.KeyError, err,
	)
}
// handleQUICConnection serves a single QUIC connection.  It accepts the
// client's streams one by one and handles each in its own goroutine via
// handleQUICStream, closing the stream afterwards.  It closes conn and returns
// once no more streams can be accepted or reqSema cannot be acquired.
//
// See also the comment on Proxy.requestsSema.
func (p *Proxy) handleQUICConnection(conn quic.Connection, reqSema syncutil.Semaphore) {
	for {
		ctx := context.Background()

		// Stub-to-resolver DNS traffic follows a simple pattern: the client
		// sends a query and the server answers it.  For each subsequent query
		// on a QUIC connection, the client MUST select the next available
		// client-initiated bidirectional stream.
		stream, err := conn.AcceptStream(ctx)
		if err != nil {
			logQUICError(ctx, "accepting quic stream", err, p.logger)

			// Close the connection to make sure resources are freed.
			closeQUICConn(conn, DoQCodeNoError, p.logger)

			return
		}

		if err = reqSema.Acquire(ctx); err != nil {
			p.logger.ErrorContext(ctx, "acquiring semaphore", slogutil.KeyError, err)

			// Close the connection to make sure resources are freed.
			closeQUICConn(conn, DoQCodeNoError, p.logger)

			return
		}

		go func() {
			defer reqSema.Release()

			p.handleQUICStream(ctx, stream, conn)

			// After the last response, the server MUST indicate through the
			// STREAM FIN mechanism that no further data will be sent on the
			// stream.
			_ = stream.Close()
		}()
	}
}
// handleQUICStream reads a single DNS query from stream and passes it to
// handleDNSRequest, which is expected to write the response back to the same
// stream.  If the query cannot be unpacked or violates the DoQ rules checked by
// validQUICMsg, conn is closed with DoQCodeProtocolError.
func (p *Proxy) handleQUICStream(ctx context.Context, stream quic.Stream, conn quic.Connection) {
	bufPtr := p.bytesPool.Get().(*[]byte)
	defer p.bytesPool.Put(bufPtr)

	// One query - one stream.  The client MUST select the next available
	// client-initiated bidirectional stream for each subsequent query on a
	// QUIC connection, so everything up to the client's STREAM FIN is a single
	// message.
	n, err := readAll(stream, *bufPtr)

	// readAll already turns io.EOF, i.e. the STREAM FIN, into a nil error.
	// The explicit io.EOF check is kept for safety.
	if (err != nil && err != io.EOF) || n < minDNSPacketSize {
		logShortQUICRead(ctx, err, p.logger)

		return
	}

	// Only the first n bytes belong to this query.  The rest of the pooled
	// buffer may still hold bytes from an earlier use.  It must never reach the
	// parser: otherwise a truncated message, or a compression pointer that
	// points past n, could be completed with that stale data.
	buf := (*bufPtr)[:n]

	// In theory, we should use ALPN to get the DoQ version properly.  However,
	// since there are not too many versions now, we only check how the DNS
	// query is encoded.  If it's sent with a 2-byte prefix, we consider this a
	// DoQ v1.  Otherwise, a draft version.
	doqVersion := DoQv1
	req := &dns.Msg{}

	// Note that we support both the old drafts and the new RFC.  In the old
	// drafts, DNS messages were not prefixed with the message length.
	packetLen := binary.BigEndian.Uint16(buf[:2])
	if packetLen == uint16(n-2) {
		err = req.Unpack(buf[2:])
	} else {
		err = req.Unpack(buf)
		doqVersion = DoQv1Draft
	}

	if err != nil {
		p.logger.ErrorContext(ctx, "unpacking quic packet", slogutil.KeyError, err)
		closeQUICConn(conn, DoQCodeProtocolError, p.logger)

		return
	}

	if !validQUICMsg(req, p.logger) {
		// If a peer encounters such an error condition, it is considered a
		// fatal error.  It SHOULD forcibly abort the connection using QUIC's
		// CONNECTION_CLOSE mechanism and SHOULD use the DoQ error code
		// DOQ_PROTOCOL_ERROR.
		closeQUICConn(conn, DoQCodeProtocolError, p.logger)

		return
	}

	d := p.newDNSContext(ProtoQUIC, req, netutil.NetAddrToAddrPort(conn.RemoteAddr()))
	d.QUICStream = stream
	d.QUICConnection = conn
	d.DoQVersion = doqVersion

	err = p.handleDNSRequest(d)
	if err != nil {
		p.logger.DebugContext(
			ctx,
			"error handling dns request",
			"proto", d.Proto,
			slogutil.KeyError, err,
		)
	}
}
// respondQUIC writes d.Res to d.QUICStream, encoded according to d.DoQVersion:
// DoQv1 messages get the 2-byte length prefix, while draft versions are written
// as is.  If there is no response, the QUIC connection is closed with
// DoQCodeInternalError and an error is returned.
func (p *Proxy) respondQUIC(d *DNSContext) error {
	if d.Res == nil {
		// No response has been produced, so close the QUIC connection now.
		closeQUICConn(d.QUICConnection, DoQCodeInternalError, p.logger)

		return errors.Error("no response to write")
	}

	packed, err := d.Res.Pack()
	if err != nil {
		return fmt.Errorf("couldn't convert message into wire format: %w", err)
	}

	var out []byte
	switch ver := d.DoQVersion; ver {
	case DoQv1:
		out = proxyutil.AddPrefix(packed)
	case DoQv1Draft:
		out = packed
	default:
		return fmt.Errorf("invalid protocol version: %d", ver)
	}

	written, err := d.QUICStream.Write(out)
	switch {
	case err != nil:
		return fmt.Errorf("conn.Write(): %w", err)
	case written != len(out):
		return fmt.Errorf("conn.Write() returned with %d != %d", written, len(out))
	default:
		return nil
	}
}
// validQUICMsg reports whether req passes the DoQ protocol checks this server
// performs.  Of the protocol errors listed in RFC 9250, only the presence of
// the edns-tcp-keepalive EDNS(0) option is checked here.
//
// See https://www.rfc-editor.org/rfc/rfc9250.html#name-protocol-errors.
func validQUICMsg(req *dns.Msg, l *slog.Logger) (ok bool) {
	// 1. A client or server receives a message with a non-zero Message ID.
	//
	// This is consciously not validated, since there are stub proxies that
	// send non-zero Message IDs.
	//
	// 2. A client or server receives a STREAM FIN before receiving all the
	// bytes for a message indicated in the 2-octet length field.
	// 3. A server receives more than one query on a stream.
	//
	// These cases are covered earlier, when the DNS message is unpacked.
	//
	// 4. The client or server does not indicate the expected STREAM FIN after
	// sending requests or responses (see Section 4.2).
	//
	// Validating this would mean waiting for the STREAM FIN before starting to
	// process the message, so it is consciously ignored.
	//
	// 5. An implementation receives a message containing the
	// edns-tcp-keepalive EDNS(0) Option [RFC7828] (see Section 5.5.2).
	opt := req.IsEdns0()
	if opt == nil {
		return true
	}

	for _, o := range opt.Option {
		if o.Option() == dns.EDNS0TCPKEEPALIVE {
			l.Debug("client sent edns0 tcp keepalive option")

			return false
		}
	}

	// 6. A client or a server attempts to open a unidirectional QUIC stream.
	//
	// This case can only be handled when writing a response.
	//
	// 7. A server receives a "replayable" transaction in 0-RTT data.
	//
	// The information necessary to validate this is not exposed by quic-go.
	return true
}
// logShortQUICRead logs a failed or too short read from a QUIC stream.  A nil
// err means the stream ended before a full DNS query was received.
func logShortQUICRead(ctx context.Context, err error, l *slog.Logger) {
	if err != nil {
		logQUICError(ctx, "reading from quic stream", err, l)

		return
	}

	l.InfoContext(ctx, "quic packet too short for dns query")
}
// qCodeNoError is the application error code that signals a graceful closing
// of a QUIC connection with no error to report.
const qCodeNoError = quic.ApplicationErrorCode(quic.NoError)

// qCodeApplicationErrorError is the application error code used for Initial
// and Handshake packets.  Such errors are considered non-critical and are not
// logged as errors.
const qCodeApplicationErrorError = quic.ApplicationErrorCode(quic.ApplicationErrorErrorCode)
// isQUICErrorForDebugLog returns true if err is a non-critical error, most
// probably related to the current QUIC implementation. err must not be nil.
//
// TODO(ameshkov): re-test when updating quic-go.
func isQUICErrorForDebugLog(err error) (ok bool) {
if errors.Is(err, quic.ErrServerClosed) {
// This error is returned when the QUIC listener was closed by us. This
// is an expected error, we don't need the detailed logs here.
return true
}
var qAppErr *quic.ApplicationError
if errors.As(err, &qAppErr) &&
(qAppErr.ErrorCode == qCodeNoError || qAppErr.ErrorCode == qCodeApplicationErrorError) {
// No need to have detailed logs for these error codes either.
//
// TODO(a.garipov): Consider adding other error codes.
return true
}
if errors.Is(err, quic.Err0RTTRejected) {
// This error is returned on AcceptStream calls when the server rejects
// 0-RTT for some reason. This is a common scenario, no need for extra
// logs.
return true
}
// This error is returned when we're trying to accept a new stream from a
// connection that had no activity for over than the keep-alive timeout.
// This is a common scenario, no need for extra logs.
var qIdleErr *quic.IdleTimeoutError
return errors.As(err, &qIdleErr)
}
// closeQUICConn closes conn with the given application error code.  A failure
// to close is only logged at debug level.
func closeQUICConn(conn quic.Connection, code quic.ApplicationErrorCode, l *slog.Logger) {
	l.Debug("closing quic conn", "addr", conn.LocalAddr(), "code", code)

	if err := conn.CloseWithError(code, ""); err != nil {
		l.Debug("closing quic connection", "code", code, slogutil.KeyError, err)
	}
}
// newServerQUICConfig returns the QUIC settings shared by the DoQ and DoH3
// servers.
func newServerQUICConfig() (conf *quic.Config) {
	conf = &quic.Config{
		MaxIdleTimeout: maxQUICIdleTimeout,
	}

	// Limit both stream kinds to the maximum uint16 value.
	conf.MaxIncomingStreams = math.MaxUint16
	conf.MaxIncomingUniStreams = math.MaxUint16

	// Enable 0-RTT for all connections on the server side.
	conf.Allow0RTT = true

	return conf
}
// quicAddrValidator decides which QUIC clients must validate their source
// address.  It keeps a small LRU cache of the IP addresses that were recently
// asked to do so.  Until its cache entry expires, such an address is let
// through without another address validation.
type quicAddrValidator struct {
	// cache holds client IP address strings; only the presence of a key
	// matters.
	cache gcache.Cache

	// ttl is how long a cached address stays exempt from validation.
	ttl time.Duration
}

// newQUICAddrValidator returns a *quicAddrValidator whose LRU cache holds up to
// cacheSize addresses, each for ttl.
func newQUICAddrValidator(cacheSize int, ttl time.Duration) (v *quicAddrValidator) {
	v = &quicAddrValidator{ttl: ttl}
	v.cache = gcache.New(cacheSize).LRU().Build()

	return v
}

// requiresValidation reports whether the server should send a QUIC Retry
// packet to the client at addr so that it proves ownership of its address.
// This verifies the client's address but increases latency, so it is only
// required once per ttl for each IP address.
func (v *quicAddrValidator) requiresValidation(addr net.Addr) (ok bool) {
	// addr must be a *net.UDPAddr here; if it's not, a panic is acceptable.
	udpAddr := addr.(*net.UDPAddr)
	key := udpAddr.IP.String()

	if v.cache.Has(key) {
		// The address was asked to validate within the last ttl.
		return false
	}

	if err := v.cache.SetWithExpire(key, true, v.ttl); err != nil {
		// Shouldn't happen, since no serialization function is set.
		panic(fmt.Errorf("quic validator: setting cache item: %w", err))
	}

	// The address is not in the cache, so require address validation.
	return true
}
// readAll reads from r into buf until r returns io.EOF or another error, and
// returns the number of bytes read.  Reaching io.EOF is not an error, so a
// successful call returns err == nil rather than io.EOF.
//
// If r holds more data than fits into buf, readAll returns io.ErrShortBuffer
// with n == len(buf).  Data that fills buf exactly is not an error: once buf
// is full, readAll makes one more small read to find out whether r is actually
// exhausted.
//
// Unlike io.ReadAll, readAll reads into the given buffer and never allocates or
// grows one.  Unlike io.ReadFull, it does not require the whole buffer to be
// filled.
func readAll(r io.Reader, buf []byte) (n int, err error) {
	for n < len(buf) {
		var read int
		read, err = r.Read(buf[n:])
		n += read
		if err != nil {
			if err == io.EOF {
				err = nil
			}

			return n, err
		}
	}

	// buf is full.  Probe r with a one-byte read: io.EOF means the data fit
	// exactly, while any further byte means that buf is too small.
	var probe [1]byte
	for {
		read, probeErr := r.Read(probe[:])
		switch {
		case read > 0:
			return n, io.ErrShortBuffer
		case probeErr == io.EOF:
			return n, nil
		case probeErr != nil:
			return n, probeErr
		}
	}
}
07070100000075000081A4000000000000000000000001679A649F00001AB3000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/proxy/server_quic_test.gopackage proxy
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/require"
)
func TestQuicProxy(t *testing.T) {
	serverConfig, caPem := newTLSConfig(t)

	roots := x509.NewCertPool()
	roots.AppendCertsFromPEM(caPem)

	clientTLS := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    roots,
		NextProtos: append([]string{NextProtoDQ}, compatProtoDQ...),
	}

	conf := &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TLSConfig:              serverConfig,
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	}

	// startProxy creates a proxy from conf, starts it, and registers its
	// shutdown as a cleanup of t.
	startProxy := func(t *testing.T) (prx *Proxy) {
		t.Helper()

		prx = mustNew(t, conf)
		ctx := context.Background()
		require.NoError(t, prx.Start(ctx))
		testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

		return prx
	}

	// dial opens a DoQ connection to addr and registers its closing as a
	// cleanup of t.
	dial := func(t *testing.T, addr *net.UDPAddr) (conn quic.EarlyConnection) {
		t.Helper()

		c, err := quic.DialAddrEarly(context.Background(), addr.String(), clientTLS, nil)
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, func() (err error) {
			return c.CloseWithError(DoQCodeNoError, "")
		})

		return c
	}

	var addr *net.UDPAddr
	t.Run("run", func(t *testing.T) {
		prx := startProxy(t)
		addr = testutil.RequireTypeAssert[*net.UDPAddr](t, prx.Addr(ProtoQUIC))
		conn := dial(t, addr)

		for range 10 {
			sendTestQUICMessage(t, conn, DoQv1)

			// Send a message encoded for a draft version as well.
			sendTestQUICMessage(t, conn, DoQv1Draft)
		}
	})

	require.False(t, t.Failed())

	// Start a new proxy on the very same address.
	conf.QUICListenAddr = []*net.UDPAddr{addr}
	conf.UpstreamConfig = newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr)

	t.Run("rerun", func(t *testing.T) {
		startProxy(t)
		conn := dial(t, addr)

		sendTestQUICMessage(t, conn, DoQv1)

		// Send a message encoded for a draft version as well.
		sendTestQUICMessage(t, conn, DoQv1Draft)
	})
}
func TestQuicProxy_largePackets(t *testing.T) {
	serverConfig, caPem := newTLSConfig(t)

	// answerLocally replies with a fixed A record, so the request never goes
	// to any real upstream.
	answerLocally := func(_ *Proxy, d *DNSContext) (err error) {
		resp := &dns.Msg{}
		resp.SetReply(d.Req)
		resp.Answer = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   d.Req.Question[0].Name,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
			},
			A: net.IP{8, 8, 8, 8},
		}}
		d.Res = resp

		return nil
	}

	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		TLSListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		HTTPSListenAddr:        []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TLSConfig:              serverConfig,
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		RequestHandler:         answerLocally,
	})

	ctx := context.Background()
	require.NoError(t, dnsProxy.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	roots := x509.NewCertPool()
	roots.AppendCertsFromPEM(caPem)

	clientTLS := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    roots,
		NextProtos: append([]string{NextProtoDQ}, compatProtoDQ...),
	}

	// Open a DNS-over-QUIC client connection.
	conn, err := quic.DialAddrEarly(
		context.Background(),
		dnsProxy.Addr(ProtoQUIC).String(),
		clientTLS,
		nil,
	)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) {
		return conn.CloseWithError(DoQCodeNoError, "")
	})

	// Pad the query with 4096 bytes so that it is large enough to take
	// multiple QUIC frames.
	query := newTestMessage()
	query.Extra = []dns.RR{&dns.OPT{
		Hdr:    dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT, Class: 4096},
		Option: []dns.EDNS0{&dns.EDNS0_PADDING{Padding: make([]byte, 4096)}},
	}}

	answer := sendQUICMessage(t, query, conn, DoQv1)
	requireResponse(t, query, answer)
}
// sendQUICMessage sends msg over a new stream of conn and returns the decoded
// response.  doqVersion selects the wire format: with [DoQv1], messages are
// preceded by a 2-byte length prefix.  Otherwise the bare message is sent, and
// the end of the response is marked only by the server's STREAM FIN.
func sendQUICMessage(
	t *testing.T,
	msg *dns.Msg,
	conn quic.Connection,
	doqVersion DoQVersion,
) (resp *dns.Msg) {
	// Open a new QUIC stream to write a test DNS query there.
	stream, err := conn.OpenStreamSync(context.Background())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, stream.Close)

	packedMsg, err := msg.Pack()
	require.NoError(t, err)

	buf := packedMsg
	if doqVersion == DoQv1 {
		buf = proxyutil.AddPrefix(packedMsg)
	}

	// Send the DNS query to the stream.
	err = writeQUICStream(buf, stream)
	require.NoError(t, err)

	// Close closes the write-direction of the stream and sends a STREAM FIN
	// packet.
	_ = stream.Close()

	// Read the complete response.  A single Read call may return only a part
	// of it, because the response can span several QUIC frames.
	var respBytes []byte
	if doqVersion == DoQv1 {
		prefix := make([]byte, 2)
		_, err = io.ReadFull(stream, prefix)
		require.NoError(t, err)

		// The prefix is the big-endian length of the message.
		respBytes = make([]byte, int(prefix[0])<<8|int(prefix[1]))
		_, err = io.ReadFull(stream, respBytes)
	} else {
		// Without a length prefix, the response ends with the STREAM FIN.
		respBytes, err = io.ReadAll(stream)
	}
	require.NoError(t, err)
	require.Greater(t, len(respBytes), minDNSPacketSize)

	// Unpack exactly the bytes that form the DNS response.
	resp = &dns.Msg{}
	err = resp.Unpack(respBytes)
	require.NoError(t, err)

	return resp
}
// writeQUICStream writes buf to stream in chunks of at most 400 bytes, so that
// tests can check how the server copes with a DNS message split across several
// writes.  When buf doesn't fit into a single chunk, it pauses for a
// millisecond after every chunk to emulate network latency.  It returns the
// first write error, if any.
func writeQUICStream(buf []byte, stream quic.Stream) (err error) {
	// 400 is an arbitrarily chosen chunk size.
	const chunkSize = 400

	chunked := len(buf) > chunkSize
	for rest := buf; len(rest) > 0; {
		n := min(chunkSize, len(rest))
		if _, err = stream.Write(rest[:n]); err != nil {
			return err
		}

		rest = rest[n:]

		if chunked {
			// Emulate network latency.
			time.Sleep(time.Millisecond)
		}
	}

	return nil
}
// sendTestQUICMessage sends a freshly built test query to conn using the given
// DoQ version and requires that the response matches it.
func sendTestQUICMessage(t *testing.T, conn quic.Connection, doqVersion DoQVersion) {
	req := newTestMessage()
	requireResponse(t, req, sendQUICMessage(t, req, conn, doqVersion))
}
07070100000076000081A4000000000000000000000001679A649F000014B7000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/server_tcp.gopackage proxy
import (
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"net"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/syncutil"
"github.com/miekg/dns"
)
// createTCPListeners opens a plain-TCP listener for every address in
// p.TCPListenAddr and appends it to p.tcpListen.  It returns an error for the
// first address that can't be listened on.  Listeners created for earlier
// addresses are kept in p.tcpListen.
func (p *Proxy) createTCPListeners(ctx context.Context) (err error) {
	for _, a := range p.TCPListenAddr {
		p.logger.Info("creating tcp server socket", "addr", a)

		lsnr, lErr := proxynetutil.ListenConfig(p.logger).Listen(
			ctx,
			bootstrap.NetworkTCP,
			a.String(),
		)
		if lErr != nil {
			return fmt.Errorf("listening to tcp socket: %w", lErr)
		}

		tcpListener, ok := lsnr.(*net.TCPListener)
		if !ok {
			// The listener won't be stored anywhere, so close it to avoid
			// leaking the socket.  The close error is not important here,
			// since the type error is returned anyway.
			_ = lsnr.Close()

			return fmt.Errorf("wrong listener type on tcp addr %s: %T", a, lsnr)
		}

		p.tcpListen = append(p.tcpListen, tcpListener)

		p.logger.Info("listening to tcp", "addr", tcpListener.Addr())
	}

	return nil
}
// createTLSListeners opens a TCP listener for every address in p.TLSListenAddr.
// It wraps each one into a TLS listener using p.TLSConfig and appends the
// result to p.tlsListen.  It stops at the first address that can't be
// listened on.
func (p *Proxy) createTLSListeners() (err error) {
	for _, addr := range p.TLSListenAddr {
		p.logger.Info("creating tls server socket", "addr", addr)

		tcpLsnr, lErr := net.ListenTCP("tcp", addr)
		if lErr != nil {
			return fmt.Errorf("listening on tls addr %s: %w", addr, lErr)
		}

		tlsLsnr := tls.NewListener(tcpLsnr, p.TLSConfig)
		p.tlsListen = append(p.tlsListen, tlsLsnr)

		p.logger.Info("listening to tls", "addr", tlsLsnr.Addr())
	}

	return nil
}
// tcpPacketLoop accepts incoming connections on l and serves each of them in
// its own goroutine.  proto must be either [ProtoTCP] or [ProtoTLS].  Every
// accepted connection takes one slot of reqSema before it is handed to
// handleTCPConnection.  The loop exits once l is closed, on any other accept
// error, or when reqSema can't be acquired.
//
// See also the comment on Proxy.requestsSema.
func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, reqSema syncutil.Semaphore) {
	p.logger.Info("entering listener loop", "proto", proto, "addr", l.Addr())

	for {
		clientConn, err := l.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				p.logger.Debug("tcp connection closed", "addr", l.Addr())
			} else {
				p.logger.Error("reading from tcp", slogutil.KeyError, err)
			}

			break
		}

		// TODO(d.kolyshev): Pass and use context from above.
		err = reqSema.Acquire(context.Background())
		if err != nil {
			p.logger.Error("acquiring semaphore", "proto", proto, slogutil.KeyError, err)

			// The connection will never be handled, so close it here instead
			// of leaking it.
			closeErr := clientConn.Close()
			if closeErr != nil {
				logWithNonCrit(closeErr, "closing conn", ProtoTCP, p.logger)
			}

			break
		}

		go p.handleTCPConnection(clientConn, proto, reqSema)
	}
}
// handleTCPConnection serves DNS requests read from conn one after another for
// as long as the proxy is started.  proto must be either [ProtoTCP] or
// [ProtoTLS].  Before each read, the connection deadline is moved
// defaultTimeout into the future.  The loop ends when a request can't be read.
// On return, conn is closed and one slot of reqSema is released.
func (p *Proxy) handleTCPConnection(conn net.Conn, proto Proto, reqSema syncutil.Semaphore) {
	// Deferred calls run in reverse order: conn is closed first, then the
	// semaphore slot is released, and RecoverAndLog runs last.
	defer slogutil.RecoverAndLog(context.TODO(), p.logger)
	defer reqSema.Release()
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			logWithNonCrit(closeErr, "closing conn", ProtoTCP, p.logger)
		}
	}()

	p.logger.Debug("handling new request", "proto", proto, "raddr", conn.RemoteAddr())

	for {
		if !p.isStarted() {
			return
		}

		// Consider deadline errors non-critical.
		if err := conn.SetDeadline(time.Now().Add(defaultTimeout)); err != nil {
			logWithNonCrit(err, "setting deadline", ProtoTCP, p.logger)
		}

		req := p.readDNSReq(conn)
		if req == nil {
			return
		}

		dctx := p.newDNSContext(proto, req, netutil.NetAddrToAddrPort(conn.RemoteAddr()))
		dctx.Conn = conn

		if err := p.handleDNSRequest(dctx); err != nil {
			logWithNonCrit(err, "handling request", ProtoTCP, p.logger)
		}
	}
}
// readDNSReq reads a single length-prefixed DNS message from conn and decodes
// it.  If reading or decoding fails, it logs the error and returns nil.
func (p *Proxy) readDNSReq(conn net.Conn) (req *dns.Msg) {
	packet, err := readPrefixed(conn)
	if err != nil {
		logWithNonCrit(err, "reading msg", ProtoTCP, p.logger)

		return nil
	}

	msg := new(dns.Msg)
	if err = msg.Unpack(packet); err != nil {
		p.logger.Error("handling tcp; unpacking msg", slogutil.KeyError, err)

		return nil
	}

	return msg
}
// errTooLarge means that a DNS message is larger than 64KiB.
const errTooLarge errors.Error = "dns message is too large"

// readPrefixed reads from conn a DNS message preceded by a 2-byte big-endian
// length prefix.  Both the prefix and the message are read in full, so data
// split across several TCP segments or TLS records is handled.  If conn is
// closed before any byte of the prefix arrives, the returned error wraps
// [io.EOF].
func readPrefixed(conn net.Conn) (b []byte, err error) {
	l := make([]byte, 2)
	// A single Read may return only one byte of the prefix, so use ReadFull.
	_, err = io.ReadFull(conn, l)
	if err != nil {
		return nil, fmt.Errorf("reading len: %w", err)
	}

	packetLen := binary.BigEndian.Uint16(l)
	// NOTE: dns.MaxMsgSize is currently 65535, the maximum uint16 value, so
	// this guard can't fire today.  It is kept in case the limit changes.
	if packetLen > dns.MaxMsgSize {
		return nil, errTooLarge
	}

	b = make([]byte, packetLen)
	_, err = io.ReadFull(conn, b)
	if err != nil {
		return nil, fmt.Errorf("reading msg: %w", err)
	}

	return b, nil
}
// respondTCP writes d.Res to the TCP or TLS client connection d.Conn, using
// 2-byte length-prefix framing.  If d.Res is nil, it closes d.Conn instead.
// A write to an already closed connection is not reported as an error.
func (p *Proxy) respondTCP(d *DNSContext) error {
	if d.Res == nil {
		// No response has been written, so close the connection right away.
		return d.Conn.Close()
	}

	packed, err := d.Res.Pack()
	if err != nil {
		return fmt.Errorf("packing message: %w", err)
	}

	err = writePrefixed(packed, d.Conn)
	switch {
	case err == nil, errors.Is(err, net.ErrClosed):
		return nil
	default:
		return fmt.Errorf("writing message: %w", err)
	}
}
// writePrefixed writes b to conn, preceded by its length as a 2-byte big-endian
// prefix, which is the DNS-over-TCP framing.  If the length of b doesn't fit
// into the prefix, it returns an error without writing anything.
func writePrefixed(b []byte, conn net.Conn) (err error) {
	// maxLen is the largest length representable by the 2-byte prefix.
	const maxLen = 1<<16 - 1

	if len(b) > maxLen {
		// Writing such a message would silently truncate its length and break
		// the client's framing.
		return fmt.Errorf("message length %d exceeds %d", len(b), maxLen)
	}

	l := make([]byte, 2)
	binary.BigEndian.PutUint16(l, uint16(len(b)))
	_, err = (&net.Buffers{l, b}).WriteTo(conn)

	return err
}
07070100000077000081A4000000000000000000000001679A649F00000692000000000000000000000000000000000000002900000000dnsproxy-0.75.0/proxy/server_tcp_test.gopackage proxy
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"testing"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
// TestTcpProxy checks that the proxy answers queries sent over plain TCP.
func TestTcpProxy(t *testing.T) {
	dnsProxy := mustStartDefaultProxy(t)

	// Create a DNS-over-TCP client connection.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	require.NoError(t, err)
	t.Cleanup(func() {
		// Best-effort close: the connection may already be closed by the time
		// the test finishes, and that shouldn't fail the test.
		_ = conn.Close()
	})

	sendTestMessages(t, conn)
}
// TestTlsProxy checks that the proxy answers queries sent over DNS-over-TLS.
func TestTlsProxy(t *testing.T) {
	serverConfig, caPem := newTLSConfig(t)
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		TLSListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		HTTPSListenAddr:        []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TLSConfig:              serverConfig,
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	// Start listening.
	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Fail early with a clear message if the CA certificate can't be parsed,
	// instead of failing later with an opaque TLS handshake error.
	roots := x509.NewCertPool()
	require.True(t, roots.AppendCertsFromPEM(caPem))
	tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots}

	// Create a DNS-over-TLS client connection.
	addr := dnsProxy.Addr(ProtoTLS)
	conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
	require.NoError(t, err)
	t.Cleanup(func() {
		// Best-effort close: the connection may already be closed by the time
		// the test finishes, and that shouldn't fail the test.
		_ = conn.Close()
	})

	sendTestMessages(t, conn)
}
07070100000078000081A4000000000000000000000001679A649F00001102000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/server_udp.gopackage proxy
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/syncutil"
"github.com/miekg/dns"
)
// createUDPListeners opens a UDP socket for every address in p.UDPListenAddr
// and appends it to p.udpListen.  It stops at the first address that can't be
// listened on.
func (p *Proxy) createUDPListeners(ctx context.Context) (err error) {
	for _, addr := range p.UDPListenAddr {
		conn, cErr := p.udpCreate(ctx, addr)
		if cErr != nil {
			return fmt.Errorf("listening on udp addr %s: %w", addr, cErr)
		}

		p.udpListen = append(p.udpListen, conn)
	}

	return nil
}
// udpCreate creates a UDP listening socket bound to udpAddr.  If
// p.Config.UDPBufferSize is positive, it is applied as the socket's read
// buffer size.  The socket options are then set with UDPSetOptions from the
// internal netutil package.  On any failure after the socket is opened, the
// socket is closed before the error is returned.
func (p *Proxy) udpCreate(ctx context.Context, udpAddr *net.UDPAddr) (*net.UDPConn, error) {
	p.logger.InfoContext(ctx, "creating udp server socket", "addr", udpAddr)

	packetConn, err := proxynetutil.ListenConfig(p.logger).ListenPacket(
		ctx,
		bootstrap.NetworkUDP,
		udpAddr.String(),
	)
	if err != nil {
		return nil, fmt.Errorf("listening to udp socket: %w", err)
	}

	// Use the checked form of the assertion, as createTCPListeners does, so
	// that an unexpected connection type results in an error instead of a
	// panic.
	udpListen, ok := packetConn.(*net.UDPConn)
	if !ok {
		_ = packetConn.Close()

		return nil, fmt.Errorf("wrong packet conn type on udp addr %s: %T", udpAddr, packetConn)
	}

	if p.Config.UDPBufferSize > 0 {
		err = udpListen.SetReadBuffer(p.Config.UDPBufferSize)
		if err != nil {
			_ = udpListen.Close()

			return nil, fmt.Errorf("setting udp buf size: %w", err)
		}
	}

	err = proxynetutil.UDPSetOptions(udpListen)
	if err != nil {
		_ = udpListen.Close()

		return nil, fmt.Errorf("setting udp opts: %w", err)
	}

	p.logger.InfoContext(ctx, "listening to udp", "addr", udpListen.LocalAddr())

	return udpListen, nil
}
// udpPacketLoop reads UDP packets from conn while the proxy is started and
// handles each non-empty packet in its own goroutine.  Each packet being
// handled holds one slot of reqSema.  The loop stops on the first read error,
// after handling any data returned together with that error.  It also stops
// when reqSema can't be acquired.
//
// See also the comment on Proxy.requestsSema.
func (p *Proxy) udpPacketLoop(conn *net.UDPConn, reqSema syncutil.Semaphore) {
	p.logger.Info("entering udp listener loop", "addr", conn.LocalAddr())

	readBuf := make([]byte, dns.MaxMsgSize)
	for p.isStarted() {
		n, localIP, remoteAddr, readErr := proxynetutil.UDPRead(conn, readBuf, p.udpOOBSize)

		// The documentation says to handle the packet even if err occurs.
		if n > 0 {
			// readBuf is overwritten by the next read while this packet is
			// handled concurrently, so give the goroutine its own copy.
			packet := make([]byte, n)
			copy(packet, readBuf[:n])

			// TODO(d.kolyshev): Pass and use context from above.
			if sErr := reqSema.Acquire(context.Background()); sErr != nil {
				p.logger.Error("acquiring semaphore", "proto", ProtoUDP, slogutil.KeyError, sErr)

				break
			}

			go func() {
				defer reqSema.Release()

				p.udpHandlePacket(packet, localIP, remoteAddr, conn)
			}()
		}

		if readErr != nil {
			logUDPConnError(readErr, conn, p.logger)

			break
		}
	}
}
// logUDPConnError logs err, which was returned from reading conn.  A closed
// connection is logged at debug level; any other error is logged at error
// level.
func logUDPConnError(err error, conn *net.UDPConn, l *slog.Logger) {
	if !errors.Is(err, net.ErrClosed) {
		l.Error("reading from udp", slogutil.KeyError, err)

		return
	}

	l.Debug("udp connection closed", "addr", conn.LocalAddr())
}
// udpHandlePacket decodes packet as a DNS request received on conn from
// remoteAddr and passes it to handleDNSRequest.  localIP is stored in the DNS
// context; respondUDP later passes it to UDPWrite.  Decoding errors are logged
// at error level and handling errors at debug level.
func (p *Proxy) udpHandlePacket(
	packet []byte,
	localIP netip.Addr,
	remoteAddr *net.UDPAddr,
	conn *net.UDPConn,
) {
	p.logger.Debug("handling new udp packet", "raddr", remoteAddr)

	req := new(dns.Msg)
	if err := req.Unpack(packet); err != nil {
		p.logger.Error("unpacking udp packet", slogutil.KeyError, err)

		return
	}

	dctx := p.newDNSContext(ProtoUDP, req, netutil.NetAddrToAddrPort(remoteAddr))
	dctx.Conn = conn
	dctx.localIP = localIP

	if err := p.handleDNSRequest(dctx); err != nil {
		p.logger.Debug("handling dns request", "proto", dctx.Proto, slogutil.KeyError, err)
	}
}
// respondUDP sends d.Res to the UDP client at d.Addr through d.Conn, passing
// d.localIP to UDPWrite.  It does nothing if no response has been set.  A
// write to an already closed connection is not reported as an error.  d.Conn
// must be a *net.UDPConn; otherwise an error is returned.
func (p *Proxy) respondUDP(d *DNSContext) error {
	resp := d.Res
	if resp == nil {
		// Do nothing if no response has been written.
		return nil
	}

	bytes, err := resp.Pack()
	if err != nil {
		return fmt.Errorf("packing message: %w", err)
	}

	// Use the checked form of the assertion so that a context with an
	// unexpected connection type results in an error instead of a panic.
	conn, ok := d.Conn.(*net.UDPConn)
	if !ok {
		return fmt.Errorf("unexpected conn type %T", d.Conn)
	}

	rAddr := net.UDPAddrFromAddrPort(d.Addr)
	n, err := proxynetutil.UDPWrite(bytes, conn, rAddr, d.localIP)
	if err != nil {
		if errors.Is(err, net.ErrClosed) {
			return nil
		}

		return fmt.Errorf("writing message: %w", err)
	}

	if n != len(bytes) {
		return fmt.Errorf("udpWrite() returned with %d != %d", n, len(bytes))
	}

	return nil
}
07070100000079000081A4000000000000000000000001679A649F00000160000000000000000000000000000000000000002900000000dnsproxy-0.75.0/proxy/server_udp_test.gopackage proxy
import (
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
// TestUdpProxy checks that the proxy answers queries sent over plain UDP.
func TestUdpProxy(t *testing.T) {
	dnsProxy := mustStartDefaultProxy(t)

	// Create a DNS-over-UDP client connection.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	require.NoError(t, err)
	t.Cleanup(func() {
		// Best-effort close: the connection may already be closed by the time
		// the test finishes, and that shouldn't fail the test.
		_ = conn.Close()
	})

	sendTestMessages(t, conn)
}
0707010000007A000081A4000000000000000000000001679A649F000013F0000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/stats.gopackage proxy
import (
"fmt"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// upstreamWithStats is a wrapper around the [upstream.Upstream] interface that
// records the outcome and the duration of the most recent exchange.
type upstreamWithStats struct {
	// upstream is the upstream DNS resolver.
	upstream upstream.Upstream

	// err is the error of the most recent DNS lookup, if any.  A successful
	// lookup resets it to nil.
	err error

	// queryDuration is the duration of the most recent DNS lookup.  Note that
	// it is recorded whether or not the lookup succeeded.
	queryDuration time.Duration
}

// type check
var _ upstream.Upstream = (*upstreamWithStats)(nil)

// Exchange implements the [upstream.Upstream] interface for
// *upstreamWithStats.  It forwards req to the wrapped upstream and records
// the resulting error and the time spent.
func (u *upstreamWithStats) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	start := time.Now()
	resp, err = u.upstream.Exchange(req)
	u.err = err
	u.queryDuration = time.Since(start)

	return resp, err
}

// Address implements the [upstream.Upstream] interface for
// *upstreamWithStats.  It returns the address of the wrapped upstream.
func (u *upstreamWithStats) Address() (addr string) {
	return u.upstream.Address()
}

// Close implements the [upstream.Upstream] interface for *upstreamWithStats.
// It closes the wrapped upstream.
func (u *upstreamWithStats) Close() (err error) {
	return u.upstream.Close()
}
// upstreamsWithStats wraps every element of upstreams into an
// [upstreamWithStats], so that statistics are gathered for it, and returns
// the wrapped upstreams in the same order.  The result is never nil.
func upstreamsWithStats(upstreams []upstream.Upstream) (wrapped []upstream.Upstream) {
	wrapped = make([]upstream.Upstream, len(upstreams))
	for i, u := range upstreams {
		wrapped[i] = &upstreamWithStats{upstream: u}
	}

	return wrapped
}
// QueryStatistics holds the DNS query statistics for the main upstream and the
// fallback DNS servers.
type QueryStatistics struct {
	main     []*UpstreamStatistics
	fallback []*UpstreamStatistics
}

// cachedQueryStatistics returns statistics for a query answered from the
// cache.  They consist of a single main entry with Address set to addr and
// IsCached set to true.
func cachedQueryStatistics(addr string) (s *QueryStatistics) {
	cached := &UpstreamStatistics{
		Address:  addr,
		IsCached: true,
	}

	return &QueryStatistics{main: []*UpstreamStatistics{cached}}
}

// Main returns the statistics for the main upstream DNS servers.
func (s *QueryStatistics) Main() (us []*UpstreamStatistics) {
	us = s.main

	return us
}

// Fallback returns the statistics for the fallback DNS servers.
func (s *QueryStatistics) Fallback() (us []*UpstreamStatistics) {
	us = s.fallback

	return us
}
// collectQueryStats gathers the statistics from the wrapped upstreams and
// returns them together with the unwrapped resolver.
//
// resolver is the upstream that successfully resolved the request, or nil if
// none did.  All provided upstreams must be of type [*upstreamWithStats];
// otherwise it panics.  unwrapped is the resolver's inner upstream, see
// [upstreamWithStats.upstream].  It is nil when resolver is nil.  A non-empty
// fallbacks means that the request was resolved by a fallback.  See
// [DNSContext.QueryStatistics] for how the result depends on the resolution
// outcome and the upstream mode.
func collectQueryStats(
	mode UpstreamMode,
	resolver upstream.Upstream,
	upstreams []upstream.Upstream,
	fallbacks []upstream.Upstream,
) (unwrapped upstream.Upstream, stats *QueryStatistics) {
	if resolver == nil {
		// The DNS query was not resolved, so report every upstream tried.
		return nil, &QueryStatistics{
			main:     collectUpstreamStats(upstreams...),
			fallback: collectUpstreamStats(fallbacks...),
		}
	}

	wrapped, ok := resolver.(*upstreamWithStats)
	if !ok {
		// Should never happen.
		panic(fmt.Errorf("unexpected type %T", resolver))
	}

	unwrapped = wrapped.upstream

	switch {
	case len(fallbacks) > 0:
		// The DNS query was resolved by a fallback resolver.
		stats = &QueryStatistics{
			main:     collectUpstreamStats(upstreams...),
			fallback: collectUpstreamStats(wrapped),
		}
	case mode == UpstreamModeFastestAddr:
		// The DNS query was resolved by the main resolvers in the
		// [UpstreamModeFastestAddr] mode, so all of them are reported.
		stats = &QueryStatistics{
			main: collectUpstreamStats(upstreams...),
		}
	default:
		// The DNS query was resolved by a single main resolver.
		stats = &QueryStatistics{
			main: collectUpstreamStats(wrapped),
		}
	}

	return unwrapped, stats
}
// UpstreamStatistics contains the DNS query statistics for a single upstream.
type UpstreamStatistics struct {
	// Error is the DNS lookup error, if any.
	Error error

	// Address is the address of the upstream DNS resolver.
	//
	// TODO(s.chzhen): Use [upstream.Upstream] when [cacheItem] starts to
	// contain one.
	Address string

	// QueryDuration is the duration of the DNS lookup.  Note that it is
	// recorded for failed lookups as well.  It is zero for responses served
	// from the cache.
	QueryDuration time.Duration

	// IsCached indicates whether the response was served from a cache.
	IsCached bool
}
// collectUpstreamStats converts the recorded data of each wrapped upstream
// into an [UpstreamStatistics], preserving the order.  Every element of
// upstreams must be of type *upstreamWithStats; otherwise it panics.  The
// result is never nil.
func collectUpstreamStats(upstreams ...upstream.Upstream) (stats []*UpstreamStatistics) {
	stats = make([]*UpstreamStatistics, len(upstreams))
	for i, u := range upstreams {
		w, ok := u.(*upstreamWithStats)
		if !ok {
			// Should never happen.
			panic(fmt.Errorf("unexpected type %T", u))
		}

		stats[i] = &UpstreamStatistics{
			Error:         w.err,
			Address:       w.Address(),
			QueryDuration: w.queryDuration,
		}
	}

	return stats
}
0707010000007B000081A4000000000000000000000001679A649F00001E58000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/stats_test.gopackage proxy_test
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCollectQueryStats checks the statistics returned by
// [proxy.DNSContext.QueryStatistics] after [proxy.Proxy.Resolve] for each
// upstream mode, with succeeding and failing main and fallback upstreams.
func TestCollectQueryStats(t *testing.T) {
	const (
		listenIP = "127.0.0.1"
	)

	var (
		testReq = &dns.Msg{
			Question: []dns.Question{{
				Name:   "test.",
				Qtype:  dns.TypeA,
				Qclass: dns.ClassINET,
			}},
		}

		defaultTrustedProxies netutil.SubnetSet = netutil.SliceSubnetSet{
			netip.MustParsePrefix("0.0.0.0/0"),
			netip.MustParsePrefix("::0/0"),
		}

		localhostAnyPort = netip.MustParseAddrPort(netutil.JoinHostPort(listenIP, 0))
	)

	// ups always succeeds, answering with a reply that has no records.
	ups := &dnsproxytest.FakeUpstream{
		OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			return (&dns.Msg{}).SetReply(req), nil
		},
		OnAddress: func() (addr string) { return "upstream" },
		OnClose:   func() (err error) { return nil },
	}

	// failUps always fails the exchange.
	failUps := &dnsproxytest.FakeUpstream{
		OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			return nil, errors.Error("exchange error")
		},
		OnAddress: func() (addr string) { return "fail.upstream" },
		OnClose:   func() (err error) { return nil },
	}

	// conf is shared by all subtests, each of which overwrites its upstream
	// fields, so the subtests must not run in parallel.
	conf := &proxy.Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	}

	// wantMainCount and wantFallbackCount are the expected lengths of the main
	// and fallback statistics.  wantMainErr and wantFallbackErr assert whether
	// any of the corresponding entries contains an error.
	testCases := []struct {
		wantErr           assert.ErrorAssertionFunc
		wantMainErr       assert.BoolAssertionFunc
		wantFallbackErr   assert.BoolAssertionFunc
		config            *proxy.UpstreamConfig
		fallbackConfig    *proxy.UpstreamConfig
		name              string
		mode              proxy.UpstreamMode
		wantMainCount     int
		wantFallbackCount int
	}{{
		wantErr:         assert.NoError,
		wantMainErr:     assert.False,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "load_balance_success",
		mode:              proxy.UpstreamModeLoadBalance,
		wantMainCount:     1,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.Error,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.True,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps},
		},
		name:              "load_balance_bad",
		mode:              proxy.UpstreamModeLoadBalance,
		wantMainCount:     1,
		wantFallbackCount: 2,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.False,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "parallel_success",
		mode:              proxy.UpstreamModeParallel,
		wantMainCount:     1,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "parallel_bad_fallback_success",
		mode:              proxy.UpstreamModeParallel,
		wantMainCount:     1,
		wantFallbackCount: 1,
	}, {
		wantErr:         assert.Error,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.True,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps, failUps},
		},
		name:              "parallel_bad",
		mode:              proxy.UpstreamModeParallel,
		wantMainCount:     2,
		wantFallbackCount: 3,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.False,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "fastest_single_success",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     1,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.False,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups, ups},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "fastest_multiple_success",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     2,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "fastest_mixed_success",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     2,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.Error,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.True,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps, failUps},
		},
		name:              "fastest_multiple_bad",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     2,
		wantFallbackCount: 3,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "fastest_bad_fallback_success",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     2,
		wantFallbackCount: 1,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			conf.UpstreamConfig = tc.config
			conf.Fallbacks = tc.fallbackConfig
			conf.UpstreamMode = tc.mode

			p, err := proxy.New(conf)
			require.NoError(t, err)

			// The proxy is not started; the request is resolved directly.
			d := &proxy.DNSContext{Req: testReq}
			err = p.Resolve(d)
			tc.wantErr(t, err)

			stats := d.QueryStatistics()
			assertQueryStats(
				t,
				stats,
				tc.wantMainCount,
				tc.wantMainErr,
				tc.wantFallbackCount,
				tc.wantFallbackErr,
			)
		})
	}
}
// assertQueryStats checks that stats has the expected number of main and
// fallback entries.  It also checks whether each group contains an error,
// using the provided bool assertions.
func assertQueryStats(
	t *testing.T,
	stats *proxy.QueryStatistics,
	wantMainCount int,
	wantMainErr assert.BoolAssertionFunc,
	wantFallbackCount int,
	wantFallbackErr assert.BoolAssertionFunc,
) {
	t.Helper()

	mainStats, fallbackStats := stats.Main(), stats.Fallback()

	assert.Lenf(t, mainStats, wantMainCount, "main stats count")
	assert.Lenf(t, fallbackStats, wantFallbackCount, "fallback stats count")

	wantMainErr(t, isErrorInStats(mainStats), "main err")
	wantFallbackErr(t, isErrorInStats(fallbackStats), "fallback err")
}
// isErrorInStats reports whether any of the upstream statistics holds a DNS
// lookup error.  It returns false for an empty slice.
func isErrorInStats(stats []*proxy.UpstreamStatistics) (ok bool) {
	for i := range stats {
		if stats[i].Error != nil {
			return true
		}
	}

	return false
}
0707010000007C000081A4000000000000000000000001679A649F000005ED000000000000000000000000000000000000002600000000dnsproxy-0.75.0/proxy/upstreammode.gopackage proxy
import (
"encoding"
"fmt"
)
// UpstreamMode enumerates the ways the proxy can use its upstreams.
//
// TODO(d.kolyshev): Set uint8 as underlying type.
type UpstreamMode string

const (
	// UpstreamModeLoadBalance is the default upstream mode.  It spreads the
	// load across the upstreams.
	UpstreamModeLoadBalance UpstreamMode = "load_balance"

	// UpstreamModeParallel makes the server query all configured upstream
	// servers in parallel.
	UpstreamModeParallel UpstreamMode = "parallel"

	// UpstreamModeFastestAddr makes the server respond to A and AAAA requests
	// only with the fastest IP address, as detected by ICMP response time or
	// TCP connection time.
	UpstreamModeFastestAddr UpstreamMode = "fastest_addr"
)

// type check
var _ encoding.TextUnmarshaler = (*UpstreamMode)(nil)

// UnmarshalText implements the [encoding.TextUnmarshaler] interface for
// *UpstreamMode.  It accepts only the names of the known modes and leaves m
// unchanged on error.
func (m *UpstreamMode) UnmarshalText(b []byte) (err error) {
	mode := UpstreamMode(b)
	switch mode {
	case UpstreamModeLoadBalance, UpstreamModeParallel, UpstreamModeFastestAddr:
		*m = mode

		return nil
	}

	return fmt.Errorf(
		"invalid upstream mode %q, supported: %q, %q, %q",
		b,
		UpstreamModeLoadBalance,
		UpstreamModeParallel,
		UpstreamModeFastestAddr,
	)
}

// type check
var _ encoding.TextMarshaler = UpstreamMode("")

// MarshalText implements the [encoding.TextMarshaler] interface for
// UpstreamMode.  The text form is the mode's string value.
func (m UpstreamMode) MarshalText() (text []byte, err error) {
	text = []byte(m)

	return text, nil
}
0707010000007D000081A4000000000000000000000001679A649F0000014C000000000000000000000000000000000000002B00000000dnsproxy-0.75.0/proxy/upstreammode_test.gopackage proxy_test
import (
"testing"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/testutil"
)
// TestUpstreamMode_encoding checks that UpstreamMode is marshaled to and
// unmarshaled from its text form.
func TestUpstreamMode_encoding(t *testing.T) {
	t.Parallel()

	mode := proxy.UpstreamModeLoadBalance
	testutil.AssertMarshalText(t, "load_balance", &mode)
	testutil.AssertUnmarshalText(t, "load_balance", &mode)
}
0707010000007E000081A4000000000000000000000001679A649F000038C1000000000000000000000000000000000000002300000000dnsproxy-0.75.0/proxy/upstreams.gopackage proxy
import (
"fmt"
"io"
"log/slog"
"maps"
"slices"
"strings"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
)
// UnqualifiedNames is a key for the [UpstreamConfig.DomainReservedUpstreams]
// map.  It specifies the upstreams used only for resolving domain names that
// consist of a single label.
const UnqualifiedNames = "unqualified_names"

// UpstreamConfig maps domain names to upstreams.  It is returned by
// [ParseUpstreamsConfig].
type UpstreamConfig struct {
	// DomainReservedUpstreams maps the domains to the upstreams.  The
	// [UnqualifiedNames] key holds the upstreams for single-label names.
	DomainReservedUpstreams map[string][]upstream.Upstream

	// SpecifiedDomainUpstreams maps the specific domain names to the upstreams.
	SpecifiedDomainUpstreams map[string][]upstream.Upstream

	// SubdomainExclusions is the set of domains with subdomain exclusions.
	SubdomainExclusions *container.MapSet[string]

	// Upstreams is the list of default upstreams.
	Upstreams []upstream.Upstream
}

// type check
var _ io.Closer = (*UpstreamConfig)(nil)
// ParseUpstreamsConfig parses lines into an UpstreamConfig.  Empty lines and
// comments (lines starting with "#") are skipped.  If every line is valid, the
// returned error is nil.  Otherwise the returned config is only partially
// filled, and the error describes the invalid lines; see [ParseError].  If
// opts is nil, empty options are used.  If opts.Logger is nil, it is set to
// [slog.Default], which modifies the caller's opts.
//
// # Simple upstreams
//
// One upstream per line, for example:
//
//	1.2.3.4
//	3.4.5.6
//
// # Domain specific upstreams
//
//   - reserved upstreams: [/domain1/../domainN/]<upstreamString>
//   - subdomains only upstreams: [/*.domain1/../*.domainN/]<upstreamString>
//
// where <upstreamString> is one or more upstreams separated by spaces, e.g.
// `1.1.1.1` or `1.1.1.1 2.2.2.2`.
//
// More specific domains take priority over less specific ones.  To exclude
// more specific domains from using the reserved upstreams, use the following
// syntax:
//
//	[/domain1/../domainN/]#
//
// For example, the config
//
//	[/host.com/]1.2.3.4
//	[/www.host.com/]2.3.4.5
//	[/maps.host.com/news.host.com/]#
//	3.4.5.6
//
// sends queries for *.host.com to 1.2.3.4.  The exceptions are *.www.host.com,
// which go to 2.3.4.5, and *.maps.host.com and *.news.host.com, which go to
// the default server 3.4.5.6 along with all other domains.
//
// To exclude the domain itself from using the reserved upstreams while
// keeping its subdomains, use:
//
//	[/*.domain.com/]<upstreamString>
//
// For example, the config
//
//	[/*.domain.com/]1.2.3.4
//	3.4.5.6
//
// sends queries for all subdomains of domain.com to 1.2.3.4.  Queries for
// domain.com itself go to the default server 3.4.5.6, like every other query.
//
// TODO(e.burkov): Consider supporting multiple upstreams in a single line for
// default upstream syntax.
func ParseUpstreamsConfig(
	lines []string,
	opts *upstream.Options,
) (conf *UpstreamConfig, err error) {
	if opts == nil {
		opts = &upstream.Options{}
	}

	// NOTE: This modifies the caller's options when they have no logger set.
	if opts.Logger == nil {
		opts.Logger = slog.Default()
	}

	parser := &configParser{
		options:                  opts,
		logger:                   opts.Logger,
		upstreamsIndex:           make(map[string]upstream.Upstream),
		domainReservedUpstreams:  make(map[string][]upstream.Upstream),
		specifiedDomainUpstreams: make(map[string][]upstream.Upstream),
		subdomainsOnlyUpstreams:  make(map[string][]upstream.Upstream),
		subdomainsOnlyExclusions: container.NewMapSet[string](),
	}

	return parser.parse(lines)
}
// ParseError describes a single line of the upstream list that couldn't be
// parsed.  [ParseUpstreamsConfig] returns one of these for every such line,
// joined together.
type ParseError struct {
	// err is the underlying error describing what is wrong with the line.
	err error

	// Idx is the zero-based index of the offending line within the lines
	// passed to [ParseUpstreamsConfig].
	Idx int
}

// type check
var _ error = (*ParseError)(nil)

// Error implements the [error] interface for *ParseError.  The message
// contains both the index of the line and the underlying error.
func (e *ParseError) Error() (msg string) {
	msg = fmt.Sprintf("parsing error at index %d: %s", e.Idx, e.err)

	return msg
}
// type check
var _ errors.Wrapper = (*ParseError)(nil)

// Unwrap implements the [errors.Wrapper] interface for *ParseError.  It returns
// the error describing what is wrong with the line, so that callers can
// inspect it with functions like [errors.Is] and [errors.As].
func (e *ParseError) Unwrap() (unwrapped error) {
	unwrapped = e.err

	return unwrapped
}
// configParser accumulates the intermediate state built while parsing the
// lines of an upstream configuration.  Its zero value is not usable, since the
// maps must be initialized; see [ParseUpstreamsConfig].
type configParser struct {
	// logger receives debug messages emitted while parsing.  It must not be
	// nil.
	logger *slog.Logger

	// options are the properties used to create every upstream.
	options *upstream.Options

	// upstreamsIndex maps upstream addresses, exactly as written in the
	// configuration, to the upstreams created for them.  It ensures that an
	// address mentioned several times results in a single upstream.
	upstreamsIndex map[string]upstream.Upstream

	// domainReservedUpstreams maps domain names to the upstreams reserved for
	// them.  A nil value means the domain is excluded from reserved upstreams.
	domainReservedUpstreams map[string][]upstream.Upstream

	// specifiedDomainUpstreams maps domain names specified without the "*."
	// wildcard prefix to their upstreams.  A nil value means the domain is
	// excluded from reserved upstreams.
	specifiedDomainUpstreams map[string][]upstream.Upstream

	// subdomainsOnlyUpstreams maps domain names specified with the "*."
	// wildcard prefix, stored without it, to the upstreams used for their
	// subdomains.
	subdomainsOnlyUpstreams map[string][]upstream.Upstream

	// subdomainsOnlyExclusions is the set of domain names specified with the
	// "*." wildcard prefix, stored without it.
	subdomainsOnlyExclusions *container.MapSet[string]

	// upstreams are the default upstreams, specified without any domains.
	upstreams []upstream.Upstream
}
// parse parses every line in lines and builds the resulting UpstreamConfig.
// Each invalid line produces a [*ParseError], and all of them are returned
// joined together.  The config is returned even then and contains everything
// parsed from the valid lines.
func (p *configParser) parse(lines []string) (c *UpstreamConfig, err error) {
	var errs []error
	for idx, line := range lines {
		lineErr := p.parseLine(idx, line)
		if lineErr == nil {
			continue
		}

		errs = append(errs, &ParseError{Idx: idx, err: lineErr})
	}

	// For domains specified with the wildcard, keep only the wildcard
	// upstreams in the reserved ones, dropping the upstreams specified for the
	// domain itself.  Those stay in specifiedDomainUpstreams.
	for host, ups := range p.subdomainsOnlyUpstreams {
		p.domainReservedUpstreams[host] = ups
	}

	c = &UpstreamConfig{
		Upstreams:                p.upstreams,
		DomainReservedUpstreams:  p.domainReservedUpstreams,
		SpecifiedDomainUpstreams: p.specifiedDomainUpstreams,
		SubdomainExclusions:      p.subdomainsOnlyExclusions,
	}

	return c, errors.Join(errs...)
}
// parseLine parses the configuration line confLine with index idx and records
// the result in p.  Empty lines and lines starting with '#' are ignored.  A
// domain specification followed by "#" excludes the domains from reserved
// upstreams.  It returns an error if the line is invalid.
func (p *configParser) parseLine(idx int, confLine string) (err error) {
	if confLine == "" || strings.HasPrefix(confLine, "#") {
		return nil
	}

	upstreams, domains, err := splitConfigLine(confLine)
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	isExclusion := upstreams[0] == "#" && len(domains) > 0
	if isExclusion {
		p.excludeFromReserved(domains)

		return nil
	}

	for _, addr := range upstreams {
		if err = p.specifyUpstream(domains, addr, idx); err != nil {
			// Don't wrap the error since it's informative enough as is.
			return err
		}
	}

	return nil
}
// splitConfigLine splits an upstream configuration line into the list of
// upstream addresses (one or many) and the list of domains for which these
// upstreams are reserved.  domains is nil for a line without a domain
// specification.  Domains are lowercased and get a trailing dot; an empty
// domain specification stands for [UnqualifiedNames].  It returns an error if
// the line format is incorrect, including a domain specification that isn't
// followed by at least one upstream, or if a domain is invalid.
func splitConfigLine(confLine string) (upstreams, domains []string, err error) {
	if !strings.HasPrefix(confLine, "[/") {
		return []string{confLine}, nil, nil
	}

	domainsLine, upstreamsLine, found := strings.Cut(confLine[len("[/"):], "/]")
	if !found {
		return nil, nil, errors.Error("wrong upstream format")
	}

	// Check the split fields rather than the raw string, so that a domain
	// specification followed only by whitespace is rejected here instead of
	// producing an empty list of upstreams.  Callers index the first upstream
	// without checking the length.
	upstreams = strings.Fields(upstreamsLine)
	if len(upstreams) == 0 {
		return nil, nil, errors.Error("wrong upstream format")
	}

	for _, confHost := range strings.Split(domainsLine, "/") {
		if confHost == "" {
			// An empty domain specification means "unqualified names only".
			domains = append(domains, UnqualifiedNames)

			continue
		}

		host := strings.TrimPrefix(confHost, "*.")
		if err = netutil.ValidateDomainName(host); err != nil {
			return nil, nil, err
		}

		domains = append(domains, strings.ToLower(confHost+"."))
	}

	return upstreams, domains, nil
}
// specifyUpstream creates the upstream for the address u, or reuses the one
// already created for the same address, and assigns it to domains.  If domains
// is empty, the upstream becomes one of the default upstreams.  idx is the
// index of the configuration line and is only used for logging.
func (p *configParser) specifyUpstream(domains []string, u string, idx int) (err error) {
	dnsUpstream, ok := p.upstreamsIndex[u]
	// TODO(e.burkov): Improve identifying duplicate upstreams.
	if !ok {
		dnsUpstream, err = upstream.AddressToUpstream(u, p.options.Clone())
		if err != nil {
			// Use %w so that callers are able to inspect the underlying error
			// with errors.Is and errors.As.
			return fmt.Errorf("cannot prepare the upstream: %w", err)
		}

		p.upstreamsIndex[u] = dnsUpstream
	}

	addr := dnsUpstream.Address()
	if len(domains) == 0 {
		// TODO(s.chzhen): Handle duplicates.
		p.upstreams = append(p.upstreams, dnsUpstream)

		// TODO(s.chzhen): Logs without index.
		p.logger.Debug("set upstream", "idx", idx, "addr", addr)

		return nil
	}

	p.includeToReserved(dnsUpstream, domains)
	p.logger.Debug(
		"upstream is reserved",
		"idx", idx,
		"addr", addr,
		"domains_num", len(domains),
	)

	return nil
}
// excludeFromReserved marks every domain in domains as excluded from reserved
// upstreams by recording a nil list of upstreams for it.  A domain with the
// "*." wildcard prefix is recorded without the prefix, in the subdomain
// exclusions and the wildcard upstreams.  Any other domain is recorded in the
// reserved and the specified upstreams.
func (p *configParser) excludeFromReserved(domains []string) {
	for _, host := range domains {
		base, isWildcard := strings.CutPrefix(host, "*.")
		if !isWildcard {
			p.domainReservedUpstreams[host] = nil
			p.specifiedDomainUpstreams[host] = nil

			continue
		}

		p.subdomainsOnlyExclusions.Add(base)
		p.subdomainsOnlyUpstreams[base] = nil
	}
}
// includeToReserved reserves dnsUpstream for every domain in domains.  A
// domain with the "*." wildcard prefix is stored without the prefix, added to
// the subdomain exclusions, and gets dnsUpstream as a wildcard upstream.  Any
// other domain gets it as a specified upstream.  Either way, dnsUpstream is
// also appended to the reserved upstreams of the domain.
func (p *configParser) includeToReserved(dnsUpstream upstream.Upstream, domains []string) {
	for _, domain := range domains {
		host, isWildcard := strings.CutPrefix(domain, "*.")
		if isWildcard {
			p.subdomainsOnlyExclusions.Add(host)
			p.logger.Debug("domain is added to exclusions list", "domain", host)

			p.subdomainsOnlyUpstreams[host] = append(p.subdomainsOnlyUpstreams[host], dnsUpstream)
		} else {
			p.specifiedDomainUpstreams[host] = append(p.specifiedDomainUpstreams[host], dnsUpstream)
		}

		p.domainReservedUpstreams[host] = append(p.domainReservedUpstreams[host], dnsUpstream)
	}
}
// validate returns an error if uc isn't configured properly.  uc is valid when
// it contains at least one default upstream.  A nil uc is invalid.  A uc
// without default upstreams and without any domain-specific entries causes
// [upstream.ErrNoUpstreams].  A uc with domain-specific entries but no default
// upstreams causes a separate error.
func (uc *UpstreamConfig) validate() (err error) {
	const (
		errNilConf   errors.Error = "upstream config is nil"
		errNoDefault errors.Error = "no default upstreams specified"
	)

	if uc == nil {
		return errNilConf
	}

	if len(uc.Upstreams) != 0 {
		return nil
	}

	hasDomainSpecific := len(uc.DomainReservedUpstreams) != 0 || len(uc.SpecifiedDomainUpstreams) != 0
	if !hasDomainSpecific {
		return upstream.ErrNoUpstreams
	}

	return errNoDefault
}
// ValidatePrivateConfig returns an error if uc isn't valid.  It also returns an
// error if uc, treated as a private upstreams configuration, reserves upstreams
// for domains other than reverse ARPA domains of subnets within privateSubnets.
// A zero-length prefix, as for the whole in-addr.arpa or ip6.arpa domain, is
// always allowed.  The errors for all invalid domains are joined together, in
// the lexical order of the domains.
func ValidatePrivateConfig(uc *UpstreamConfig, privateSubnets netutil.SubnetSet) (err error) {
	err = uc.validate()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	domains := slices.Sorted(maps.Keys(uc.DomainReservedUpstreams))

	var errs []error
	for _, domain := range domains {
		pref, extErr := netutil.ExtractReversedAddr(domain)
		if extErr != nil {
			// Don't wrap the error since it's informative enough as is.
			errs = append(errs, extErr)

			continue
		}

		// Allow private subnets for subdomains of the root domain, i.e.
		// zero-length prefixes.
		if pref.Bits() != 0 && !privateSubnets.Contains(pref.Addr()) {
			errs = append(errs, fmt.Errorf("reversed subnet in %q is not private", domain))
		}
	}

	return errors.Join(errs...)
}
// getUpstreamsForDomain returns the upstreams specified for resolving fqdn. It
// always returns the default set of upstreams if the domain is not reserved for
// any other upstreams.
//
// More specific domains take priority over less specific ones. For example, if
// the upstreams specified for the following domains:
//
//   - host.com
//   - www.host.com
//
// The request for mail.host.com will be resolved using the upstreams specified
// for host.com.
//
// fqdn is matched case-insensitively.  It's presumably expected to end with a
// dot, like the keys stored by [ParseUpstreamsConfig]; confirm for callers
// that build [UpstreamConfig] manually.
func (uc *UpstreamConfig) getUpstreamsForDomain(fqdn string) (ups []upstream.Upstream) {
	if len(uc.DomainReservedUpstreams) == 0 {
		return uc.Upstreams
	}

	fqdn = strings.ToLower(fqdn)
	// Domains specified with the "*." wildcard need a separate lookup, which
	// prefers the upstreams specified for the exact domain over the wildcard
	// ones.
	if uc.SubdomainExclusions.Has(fqdn) {
		return uc.lookupSubdomainExclusion(fqdn)
	}

	ups, ok := uc.lookupUpstreams(fqdn)
	if ok {
		return ups
	}

	// Strip the first label.  A single-label name, like "host.", becomes empty
	// and is then looked up as an unqualified name.
	if _, fqdn, _ = strings.Cut(fqdn, "."); fqdn == "" {
		fqdn = UnqualifiedNames
	}

	// Walk up the parent domains, one label at a time, until a reserved one is
	// found or the labels run out.
	for fqdn != "" {
		if ups, ok = uc.lookupUpstreams(fqdn); ok {
			return ups
		}

		_, fqdn, _ = strings.Cut(fqdn, ".")
	}

	return uc.Upstreams
}
// getUpstreamsForDS is like [UpstreamConfig.getUpstreamsForDomain], but it is
// intended for DS queries only.  It matches fqdn without its first label,
// since DS records belong to the delegation point rather than to the apex of
// the child zone.  If nothing is left after removing the first label, the
// default upstreams are returned.
//
// A DS RRset SHOULD be present at a delegation point when the child zone is
// signed. The DS RRset MAY contain multiple records, each referencing a public
// key in the child zone used to verify the RRSIGs in that zone. All DS RRsets
// in a zone MUST be signed, and DS RRsets MUST NOT appear at a zone's apex.
//
// See https://datatracker.ietf.org/doc/html/rfc4035#section-2.4
func (uc *UpstreamConfig) getUpstreamsForDS(fqdn string) (ups []upstream.Upstream) {
	_, parent, _ := strings.Cut(fqdn, ".")
	if parent != "" {
		return uc.getUpstreamsForDomain(parent)
	}

	return uc.Upstreams
}
// lookupSubdomainExclusion returns the upstreams for host, which is a domain
// from uc.SubdomainExclusions.  The non-empty upstreams specified for host
// itself take priority.  Next come the non-empty reserved upstreams of its
// parent domain, and finally the default upstreams.
func (uc *UpstreamConfig) lookupSubdomainExclusion(host string) (u []upstream.Upstream) {
	if ups := uc.SpecifiedDomainUpstreams[host]; len(ups) > 0 {
		return ups
	}

	// Check if there is a spec for the parent domain.  Use strings.Cut instead
	// of indexing the result of a split, so that a host without any dots,
	// which is possible in a manually constructed UpstreamConfig, doesn't cause
	// an index out of range panic.
	_, parent, _ := strings.Cut(host, ".")
	if ups := uc.DomainReservedUpstreams[parent]; len(ups) > 0 {
		return ups
	}

	return uc.Upstreams
}
// lookupUpstreams returns the upstreams reserved for name and whether name is
// present in uc.DomainReservedUpstreams at all.  A present name with no
// upstreams has been excluded from reserved upstreams, so the default upstreams
// are returned for it.
func (uc *UpstreamConfig) lookupUpstreams(name string) (ups []upstream.Upstream, ok bool) {
	reserved, found := uc.DomainReservedUpstreams[name]
	switch {
	case !found:
		return reserved, false
	case len(reserved) == 0:
		// The domain has been excluded from reserved upstreams querying.
		return uc.Upstreams, true
	default:
		return reserved, true
	}
}
// Close implements the io.Closer interface for *UpstreamConfig.  It closes the
// default upstreams first.  It then closes the upstreams of
// uc.DomainReservedUpstreams and uc.SpecifiedDomainUpstreams, each in the
// lexical order of the domains.  The errors collected by closeAll are wrapped
// into a single error.
//
// NOTE(review): the same upstream instance may be present in several of these
// lists and is then closed more than once.
func (uc *UpstreamConfig) Close() (err error) {
	errs := closeAll(nil, uc.Upstreams...)

	domainSpecific := []map[string][]upstream.Upstream{
		uc.DomainReservedUpstreams,
		uc.SpecifiedDomainUpstreams,
	}
	for _, domainUps := range domainSpecific {
		for _, domain := range slices.Sorted(maps.Keys(domainUps)) {
			errs = closeAll(errs, domainUps[domain]...)
		}
	}

	if len(errs) == 0 {
		return nil
	}

	return fmt.Errorf("failed to close some upstreams: %w", errors.Join(errs...))
}
0707010000007F000081A4000000000000000000000001679A649F00002D00000000000000000000000000000000000000003100000000dnsproxy-0.75.0/proxy/upstreams_internal_test.gopackage proxy
import (
"testing"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(e.burkov): Call [testing.T.Parallel] in this file.
// Domain specifications and the questions for them used in tests of
// [UpstreamConfig].
const (
	// unqualifiedFQDN is a single-label question, resolved using the upstreams
	// for unqualified names.
	unqualifiedFQDN = "unqualified."

	// unspecifiedFQDN is a question for a domain without upstreams reserved in
	// testUpstreamConfigLines.
	unspecifiedFQDN = "unspecified.domain."

	// The top-level domain, its subdomain, and the wildcard for the subdomains
	// of the top-level domain.
	topLevelDomain           = "example"
	topLevelFQDN             = topLevelDomain + "."
	firstLevelDomain         = "name." + topLevelDomain
	firstLevelFQDN           = firstLevelDomain + "."
	wildcardFirstLevelDomain = "*." + topLevelDomain

	// Subdomains of firstLevelDomain, and the wildcard for all of them.
	subDomain      = "sub." + firstLevelDomain
	subFQDN        = subDomain + "."
	generalDomain  = "general." + firstLevelDomain
	generalFQDN    = generalDomain + "."
	wildcardDomain = "*." + firstLevelDomain
	anotherSubFQDN = "another." + firstLevelDomain + "."
)
// Upstream URLs used in tests of [UpstreamConfig].  Each of them is used in a
// single line of testUpstreamConfigLines, so the selected upstream identifies
// the line that matched.
const (
	generalUpstream     = "tcp://general.upstream:53"
	unqualifiedUpstream = "tcp://unqualified.upstream:53"
	tldUpstream         = "tcp://tld.upstream:53"
	domainUpstream      = "tcp://domain.upstream:53"
	wildcardUpstream    = "tcp://wildcard.upstream:53"
	subdomainUpstream   = "tcp://subdomain.upstream:53"
)
// testUpstreamConfigLines is the common set of upstream configurations used in
// tests of [UpstreamConfig].
var testUpstreamConfigLines = []string{
	// The default upstream.
	generalUpstream,
	// The upstream for unqualified names.
	"[//]" + unqualifiedUpstream,
	// The upstream for the top-level domain.
	"[/" + topLevelDomain + "/]" + tldUpstream,
	// Exclude the subdomains of the top-level domain from reserved upstreams.
	"[/" + wildcardFirstLevelDomain + "/]#",
	"[/" + firstLevelDomain + "/]" + domainUpstream,
	// The upstream for the subdomains of firstLevelDomain only.
	"[/" + wildcardDomain + "/]" + wildcardUpstream,
	// Exclude generalDomain and its subdomains from reserved upstreams.
	"[/" + generalDomain + "/]#",
	"[/" + subDomain + "/]" + subdomainUpstream,
}
// TestUpstreamConfig_GetUpstreamsForDomain checks which upstreams are selected
// for questions about the domains configured in testUpstreamConfigLines.
func TestUpstreamConfig_GetUpstreamsForDomain(t *testing.T) {
	t.Parallel()

	config, err := ParseUpstreamsConfig(testUpstreamConfigLines, nil)
	require.NoError(t, err)

	testCases := []struct {
		name string
		in   string
		want []string
	}{{
		name: "unspecified",
		in:   unspecifiedFQDN,
		want: []string{generalUpstream},
	}, {
		name: "unqualified",
		in:   unqualifiedFQDN,
		want: []string{unqualifiedUpstream},
	}, {
		name: "tld",
		in:   topLevelFQDN,
		want: []string{tldUpstream},
	}, {
		// The wildcard exclusion for the top-level domain sends its
		// subdomains to the default upstreams.
		name: "unspecified_subdomain",
		in:   unspecifiedFQDN + topLevelFQDN,
		want: []string{generalUpstream},
	}, {
		name: "domain",
		in:   firstLevelFQDN,
		want: []string{domainUpstream},
	}, {
		name: "wildcard",
		in:   anotherSubFQDN,
		want: []string{wildcardUpstream},
	}, {
		// generalDomain is explicitly excluded from reserved upstreams.
		name: "general",
		in:   generalFQDN,
		want: []string{generalUpstream},
	}, {
		name: "subdomain",
		in:   subFQDN,
		want: []string{subdomainUpstream},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ups := config.getUpstreamsForDomain(tc.in)
			assertUpstreamsAddrs(t, ups, tc.want)
		})
	}
}
// TestUpstreamConfig_GetUpstreamsForDS checks that DS questions are matched
// without their first label, against the domains configured in
// testUpstreamConfigLines.
func TestUpstreamConfig_GetUpstreamsForDS(t *testing.T) {
	t.Parallel()

	config, err := ParseUpstreamsConfig(testUpstreamConfigLines, nil)
	require.NoError(t, err)

	testCases := []struct {
		name string
		in   string
		want []string
	}{{
		// The parent of unspecifiedFQDN is a single-label name, so it's
		// resolved using the upstreams for unqualified names.
		name: "unspecified",
		in:   unspecifiedFQDN,
		want: []string{unqualifiedUpstream},
	}, {
		// Nothing is left after removing the only label, so the default
		// upstreams are used.
		name: "unqualified",
		in:   unqualifiedFQDN,
		want: []string{generalUpstream},
	}, {
		name: "tld",
		in:   topLevelFQDN,
		want: []string{generalUpstream},
	}, {
		name: "unspecified_subdomain",
		in:   unspecifiedFQDN + topLevelFQDN,
		want: []string{generalUpstream},
	}, {
		name: "domain",
		in:   firstLevelFQDN,
		want: []string{tldUpstream},
	}, {
		name: "wildcard",
		in:   anotherSubFQDN,
		want: []string{domainUpstream},
	}, {
		name: "general",
		in:   "label." + generalFQDN,
		want: []string{generalUpstream},
	}, {
		name: "subdomain",
		in:   "label." + subFQDN,
		want: []string{subdomainUpstream},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ups := config.getUpstreamsForDS(tc.in)
			assertUpstreamsAddrs(t, ups, tc.want)
		})
	}
}
// TestUpstreamConfig_Validate checks the validation of configs parsed from
// different sets of lines, as well as of a nil config.
func TestUpstreamConfig_Validate(t *testing.T) {
	// errors.Error is a string-based type, so an error constructed with the
	// same text compares equal to the one returned by validate.
	testCases := []struct {
		name    string
		wantErr error
		in      []string
	}{{
		name:    "empty",
		wantErr: upstream.ErrNoUpstreams,
		in:      []string{},
	}, {
		name:    "nil",
		wantErr: upstream.ErrNoUpstreams,
		in:      nil,
	}, {
		name:    "valid",
		wantErr: nil,
		in: []string{
			"udp://upstream.example:53",
		},
	}, {
		// Only domain-specific upstreams and exclusions, without a default
		// upstream.
		name:    "no_default",
		wantErr: errors.Error("no default upstreams specified"),
		in: []string{
			"[/domain.example/]udp://upstream.example:53",
			"[/another.domain.example/]#",
		},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c, err := ParseUpstreamsConfig(tc.in, nil)
			require.NoError(t, err)

			assert.ErrorIs(t, c.validate(), tc.wantErr)
		})
	}

	t.Run("actual_nil", func(t *testing.T) {
		assert.ErrorIs(t, (*UpstreamConfig)(nil).validate(), errors.Error("upstream config is nil"))
	})
}
// TestValidatePrivateConfig checks that a private upstreams configuration only
// accepts reverse ARPA domains of subnets for which netutil.IsLocallyServed
// reports true.
func TestValidatePrivateConfig(t *testing.T) {
	ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)

	testCases := []struct {
		name    string
		wantErr string
		u       string
	}{{
		name:    "success_address",
		wantErr: ``,
		u:       "[/1.0.0.127.in-addr.arpa/]#",
	}, {
		name:    "success_subnet",
		wantErr: ``,
		u:       "[/127.in-addr.arpa/]#",
	}, {
		name:    "success_v4_family",
		wantErr: ``,
		u:       "[/in-addr.arpa/]#",
	}, {
		name:    "success_v6_family",
		wantErr: ``,
		u:       "[/ip6.arpa/]#",
	}, {
		name:    "bad_arpa_domain",
		wantErr: `bad arpa domain name "arpa": not a reversed ip network`,
		u:       "[/arpa/]#",
	}, {
		name:    "not_arpa_subnet",
		wantErr: `bad arpa domain name "hello.world": not a reversed ip network`,
		u:       "[/hello.world/]#",
	}, {
		name:    "non-private_arpa_address",
		wantErr: `reversed subnet in "1.2.3.4.in-addr.arpa." is not private`,
		u:       "[/1.2.3.4.in-addr.arpa/]#",
	}, {
		name:    "non-private_arpa_subnet",
		wantErr: `reversed subnet in "128.in-addr.arpa." is not private`,
		u:       "[/128.in-addr.arpa/]#",
	}, {
		// The errors are reported in the lexical order of the domains.
		name: "several_bad",
		wantErr: `reversed subnet in "1.2.3.4.in-addr.arpa." is not private` +
			"\n" + `bad arpa domain name "non.arpa": not a reversed ip network`,
		u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
	}, {
		name:    "partial_good",
		wantErr: "",
		u:       "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
	}}

	for _, tc := range testCases {
		// Add a default upstream so that the config passes the basic
		// validation and only the domains are checked.
		set := []string{"192.168.0.1", tc.u}
		t.Run(tc.name, func(t *testing.T) {
			upsConf, err := ParseUpstreamsConfig(set, nil)
			require.NoError(t, err)

			testutil.AssertErrorMsg(t, tc.wantErr, ValidatePrivateConfig(upsConf, ss))
		})
	}
}
// TestGetUpstreamsForDomainWithoutDuplicates checks that an upstream address
// specified for several domains results in a single shared upstream instance.
func TestGetUpstreamsForDomainWithoutDuplicates(t *testing.T) {
	upstreams := []string{"[/example.com/]1.1.1.1", "[/example.org/]1.1.1.1"}
	config, err := ParseUpstreamsConfig(upstreams, &upstream.Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: false,
		Bootstrap:          nil,
		Timeout:            testTimeout,
	})
	// Use require here and below.  The following statements dereference
	// config and index its slices, which would panic instead of failing the
	// test.
	require.NoError(t, err)

	assert.Len(t, config.Upstreams, 0)
	require.Len(t, config.DomainReservedUpstreams, 2)

	comUps := config.DomainReservedUpstreams["example.com."]
	orgUps := config.DomainReservedUpstreams["example.org."]
	require.NotEmpty(t, comUps)
	require.NotEmpty(t, orgUps)

	// Check that the very same Upstream instance is used for both domains.
	assert.Same(t, comUps[0], orgUps[0])
}
// TestGetUpstreamsForDomain_wildcards checks the interaction of specified
// domains, wildcard domains, and exclusions at several nesting levels.
func TestGetUpstreamsForDomain_wildcards(t *testing.T) {
	conf := []string{
		"0.0.0.1",
		"[/a.x/]0.0.0.2",
		"[/*.a.x/]0.0.0.3",
		"[/b.a.x/]0.0.0.4",
		"[/*.b.a.x/]0.0.0.5",
		"[/*.x.z/]0.0.0.6",
		"[/c.b.a.x/]#",
	}

	uconf, err := ParseUpstreamsConfig(conf, nil)
	require.NoError(t, err)

	testCases := []struct {
		name string
		in   string
		want []string
	}{{
		name: "default",
		in:   "d.x.",
		want: []string{"0.0.0.1:53"},
	}, {
		name: "specified_one",
		in:   "a.x.",
		want: []string{"0.0.0.2:53"},
	}, {
		name: "wildcard",
		in:   "c.a.x.",
		want: []string{"0.0.0.3:53"},
	}, {
		name: "specified_two",
		in:   "b.a.x.",
		want: []string{"0.0.0.4:53"},
	}, {
		name: "wildcard_two",
		in:   "d.b.a.x.",
		want: []string{"0.0.0.5:53"},
	}, {
		// c.b.a.x is excluded, so the default upstream is used.
		name: "specified_three",
		in:   "c.b.a.x.",
		want: []string{"0.0.0.1:53"},
	}, {
		// The exclusion of c.b.a.x also covers its subdomains.
		name: "specified_four",
		in:   "d.c.b.a.x.",
		want: []string{"0.0.0.1:53"},
	}, {
		// The wildcard for x.z doesn't cover x.z itself.
		name: "unspecified_wildcard",
		in:   "x.z.",
		want: []string{"0.0.0.1:53"},
	}, {
		name: "unspecified_wildcard_sub",
		in:   "a.x.z.",
		want: []string{"0.0.0.6:53"},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ups := uconf.getUpstreamsForDomain(tc.in)
			assertUpstreamsAddrs(t, ups, tc.want)
		})
	}
}
// TestGetUpstreamsForDomain_sub_wildcards checks nested wildcards for which the
// intermediate domain, b.a.x, has no upstreams specified.
func TestGetUpstreamsForDomain_sub_wildcards(t *testing.T) {
	conf := []string{
		"0.0.0.1",
		"[/a.x/]0.0.0.2",
		"[/*.a.x/]0.0.0.3",
		"[/*.b.a.x/]0.0.0.5",
	}

	uconf, err := ParseUpstreamsConfig(conf, nil)
	require.NoError(t, err)

	testCases := []struct {
		name string
		in   string
		want []string
	}{{
		name: "specified",
		in:   "a.x.",
		want: []string{"0.0.0.2:53"},
	}, {
		name: "wildcard",
		in:   "c.a.x.",
		want: []string{"0.0.0.3:53"},
	}, {
		// b.a.x itself isn't covered by its own wildcard, so the wildcard of
		// the parent domain is used.
		name: "sub_spec_ignore",
		in:   "b.a.x.",
		want: []string{"0.0.0.3:53"},
	}, {
		name: "sub_spec_wildcard",
		in:   "d.b.a.x.",
		want: []string{"0.0.0.5:53"},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ups := uconf.getUpstreamsForDomain(tc.in)
			assertUpstreamsAddrs(t, ups, tc.want)
		})
	}
}
// TestGetUpstreamsForDomain_default_wildcards checks that a wildcard exclusion
// sends the subdomains of a specified domain to the default upstream, while the
// domain itself keeps its upstream.
func TestGetUpstreamsForDomain_default_wildcards(t *testing.T) {
	conf := []string{
		"127.0.0.1:5301",
		"[/example.org/]127.0.0.1:5302",
		"[/*.example.org/]127.0.0.1:5303",
		"[/www.example.org/]127.0.0.1:5304",
		"[/*.www.example.org/]#",
	}

	uconf, err := ParseUpstreamsConfig(conf, nil)
	require.NoError(t, err)

	testCases := []struct {
		name string
		in   string
		want []string
	}{{
		name: "domain",
		in:   "example.org.",
		want: []string{"127.0.0.1:5302"},
	}, {
		name: "sub_wildcard",
		in:   "sub.example.org.",
		want: []string{"127.0.0.1:5303"},
	}, {
		name: "spec_sub",
		in:   "www.example.org.",
		want: []string{"127.0.0.1:5304"},
	}, {
		name: "def_wildcard",
		in:   "abc.www.example.org.",
		want: []string{"127.0.0.1:5301"},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ups := uconf.getUpstreamsForDomain(tc.in)
			assertUpstreamsAddrs(t, ups, tc.want)
		})
	}
}
// upsSink is the typed sink variable for the result of benchmarked function.
var upsSink []upstream.Upstream

// BenchmarkGetUpstreamsForDomain measures the lookup of upstreams for a mix of
// reserved, excluded, unqualified, and unspecified domain names.
func BenchmarkGetUpstreamsForDomain(b *testing.B) {
	upstreams := []string{
		"[/google.com/local/]4.3.2.1",
		"[/www.google.com//]1.2.3.4",
		"[/maps.google.com/]#",
		"[/www.google.com/]tls://1.1.1.1",
	}

	config, err := ParseUpstreamsConfig(upstreams, &upstream.Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: false,
		Bootstrap:          nil,
		Timeout:            testTimeout,
	})
	// Don't benchmark a partially parsed config silently.
	require.NoError(b, err)

	domains := []string{
		"www.google.com.",
		"www2.google.com.",
		"internal.local.",
		"google.",
		"maps.google.com.",
	}
	l := len(domains)

	// Exclude the setup above from the measurements.
	b.ReportAllocs()
	b.ResetTimer()
	for i := range b.N {
		upsSink = config.getUpstreamsForDomain(domains[i%l])
	}
}
// assertUpstreamsAddrs requires ups to have exactly as many elements as want.
// It then asserts that the address of every upstream equals the element of
// want at the same index.
func assertUpstreamsAddrs(tb testing.TB, ups []upstream.Upstream, want []string) {
	tb.Helper()

	require.Len(tb, ups, len(want))
	for i := range ups {
		got := ups[i].Address()
		assert.Equalf(tb, want[i], got, "at index %d", i)
	}
}
07070100000080000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001A00000000dnsproxy-0.75.0/proxyutil07070100000081000081A4000000000000000000000001679A649F000002A8000000000000000000000000000000000000002100000000dnsproxy-0.75.0/proxyutil/dns.go// Package proxyutil contains helper functions that are used in all other
// dnsproxy packages.
package proxyutil
import (
"encoding/binary"
"net/netip"
"github.com/miekg/dns"
)
// AddPrefix returns a new slice containing b preceded by its length, encoded
// as a 2-byte big-endian integer.  b itself is not modified.  Note that the
// length is truncated to 16 bits, so b is expected to be at most 65535 bytes
// long.
func AddPrefix(b []byte) (m []byte) {
	m = make([]byte, 0, 2+len(b))
	m = binary.BigEndian.AppendUint16(m, uint16(len(b)))

	return append(m, b...)
}
// IPFromRR returns the IP address from rr if rr is an A or AAAA record.  It
// returns the zero netip.Addr for other record types.  It also returns the zero
// address for A and AAAA records whose address data can't be converted.
func IPFromRR(rr dns.RR) (ip netip.Addr) {
	switch rec := rr.(type) {
	case *dns.A:
		ip, _ = netip.AddrFromSlice(rec.A.To4())
	case *dns.AAAA:
		ip, _ = netip.AddrFromSlice(rec.AAAA)
	}

	return ip
}
07070100000082000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001800000000dnsproxy-0.75.0/scripts07070100000083000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001E00000000dnsproxy-0.75.0/scripts/hooks07070100000084000081ED000000000000000000000001679A649F00000867000000000000000000000000000000000000002900000000dnsproxy-0.75.0/scripts/hooks/pre-commit#!/bin/sh
set -e -f -u

# This comment is used to simplify checking local copies of the script.
# Bump this number every time a significant change is made to this
# script.
#
# AdGuard-Project-Version: 2

# TODO(a.garipov): Add pre-merge-commit.

# Only show interactive prompts if a terminal is attached to stdout.  While
# this technically doesn't guarantee that reading from /dev/tty works, this
# should work reasonably well on all of our supported development systems and
# in most terminal emulators.
if [ -t '1' ]
then
	is_tty='1'
else
	is_tty='0'
fi
readonly is_tty
# prompt asks the programmer whether to proceed with the commit.  Without a
# terminal, it sleeps for two seconds instead, giving the programmer some time
# to react, and returns with a zero exit code.  With a terminal, "y" or "Y"
# continues the commit, an empty answer, "n", or "N" aborts it by exiting the
# whole script with code 1, and any other answer repeats the question.
prompt() {
	if [ "$is_tty" -eq '0' ]
	then
		sleep 2

		return 0
	fi

	while true
	do
		printf 'commit anyway? y/[n]: '
		read -r ans < /dev/tty

		case "$ans"
		in
		('y'|'Y')
			return 0
			;;
		(''|'n'|'N')
			exit 1
			;;
		esac
	done
}
# Warn the programmer about unstaged changes and untracked files, but do
# not fail the commit, because those changes might be temporary or for
# a different branch.
#
# In the porcelain v2 format, the second field of an ordinary changed entry is
# its XY status, where a dot in the second position means that there are no
# unstaged changes, and the ninth field is its path.  Untracked files are
# printed as lines starting with "?".
#
# NOTE(review): the first rule isn't restricted to ordinary entries.  It also
# matches untracked ("?") lines, for which it may print an empty ninth field.
# It also matches renamed ("2") entries, whose ninth field is the rename score
# rather than the path.  Consider checking that the first field is "1".
awk_prog='substr($2, 2, 1) != "." { print $9; } $1 == "?" { print $2; }'
readonly awk_prog

unstaged="$( git status --porcelain=2 | awk "$awk_prog" )"
readonly unstaged

if [ "$unstaged" != "" ]
then
	printf 'WARNING: you have unstaged changes:\n\n%s\n\n' "$unstaged"
	prompt
fi

# Warn the programmer about temporary todos, but do not fail the commit,
# because the commit could be in a temporary branch.
#
# This script itself is excluded from the search, since it contains the
# pattern.  The "|| :" ignores the failure status of git grep, presumably
# returned when nothing matches, which would otherwise abort the script because
# of set -e.
temp_todos="$( git grep -e 'TODO.*!!' -- ':!scripts/hooks/pre-commit' || : )"
readonly temp_todos

if [ "$temp_todos" != "" ]
then
	printf 'WARNING: you have temporary todos:\n\n%s\n\n' "$temp_todos"
	prompt
fi
verbose="${VERBOSE:-0}"
readonly verbose

# Run the txt-lint target only if some Markdown or YAML files are staged.
if [ "$( git diff --cached --name-only -- '*.md' '*.yaml' '*.yml' )" ]
then
	make VERBOSE="$verbose" txt-lint
fi

# Run the go-os-check, go-lint, and go-test targets only if some Go sources,
# module files, shell scripts, or the Makefile are staged.
if [ "$( git diff --cached --name-only -- '*.go' '*.mod' '*.sh' 'Makefile' )" ]
then
	make VERBOSE="$verbose" go-os-check go-lint go-test
fi
07070100000085000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001D00000000dnsproxy-0.75.0/scripts/make07070100000086000081A4000000000000000000000001679A649F00000BC4000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/scripts/make/build-docker.sh#!/bin/sh
verbose="${VERBOSE:-0}"

# Show every command and pass the debug flag to docker if the caller requested
# verbosity level greater than 0.
if [ "$verbose" -gt '0' ]
then
	set -x
	debug_flags='--debug=1'
else
	set +x
	debug_flags='--debug=0'
fi
readonly debug_flags

# Exit on errors (-e), disable filename expansion (-f), and treat unset
# variables as errors (-u).
set -e -f -u

# Require these to be set.
commit="${REVISION:?please set REVISION}"
dist_dir="${DIST_DIR:?please set DIST_DIR}"
version="${VERSION:?please set VERSION}"
readonly commit dist_dir version

# Allow users to use sudo.
sudo_cmd="${SUDO:-}"
readonly sudo_cmd

# The comma-separated list of platforms to build the image for.  The lines
# below are joined by the escaped newlines inside the quotes, so they must not
# be indented.
docker_platforms="\
linux/386,\
linux/amd64,\
linux/arm/v6,\
linux/arm/v7,\
linux/arm64,\
linux/ppc64le"
readonly docker_platforms

build_date="$( date -u +'%Y-%m-%dT%H:%M:%SZ' )"
readonly build_date

# Set DOCKER_IMAGE_NAME to 'adguard/dnsproxy' if you want (and are allowed)
# to push to DockerHub.
docker_image_name="${DOCKER_IMAGE_NAME:-dnsproxy-dev}"
readonly docker_image_name

# Set DOCKER_OUTPUT to 'type=image,name=adguard/dnsproxy,push=true' if you
# want (and are allowed) to push to DockerHub.
#
# If you want to inspect the resulting image using commands like "docker image
# ls", change type to docker and also set docker_platforms to a single platform.
#
# See https://github.com/docker/buildx/issues/166.
docker_output="${DOCKER_OUTPUT:-type=image,name=${docker_image_name},push=false}"
readonly docker_output

docker_version_tag="--tag=${docker_image_name}:${version}"
docker_channel_tag="--tag=${docker_image_name}:latest"

# If version is set to 'dev' or empty, only set the version tag and avoid
# polluting the "latest" tag.
#
# NOTE(review): VERSION is already required to be non-empty above, so the
# emptiness check can't be true here.
if [ "${version:-}" = 'dev' ] || [ "${version:-}" = '' ]
then
	docker_channel_tag=""
fi
readonly docker_version_tag docker_channel_tag
# Copy the binaries into a new directory under new names, so that it's easier to
# COPY them later.  DO NOT remove the trailing underscores.  See file
# docker/Dockerfile.
#
# NOTE(review): the ARM v6 and v7 names end with the variant instead of an
# underscore.  Presumably the Dockerfile builds the name from the target
# architecture and variant, where the variant is empty for the other
# platforms; verify against it.
dist_docker="${dist_dir}/docker"
readonly dist_docker

mkdir -p "$dist_docker"
cp "${dist_dir}/linux-386/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_386_"
cp "${dist_dir}/linux-amd64/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_amd64_"
cp "${dist_dir}/linux-arm64/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_arm64_"
cp "${dist_dir}/linux-arm6/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_arm_v6"
cp "${dist_dir}/linux-arm7/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_arm_v7"
cp "${dist_dir}/linux-ppc64le/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_ppc64le_"

# Prepare the default configuration for the Docker image.
cp ./config.yaml.dist "${dist_docker}/config.yaml"

# Don't use quotes with $docker_version_tag and $docker_channel_tag, because we
# want word splitting, and no argument at all if a tag is empty.
#
# TODO(a.garipov): Once flag --tag of docker buildx build supports commas, use
# them instead.
$sudo_cmd docker\
	"$debug_flags"\
	buildx build\
	--build-arg BUILD_DATE="$build_date"\
	--build-arg DIST_DIR="$dist_dir"\
	--build-arg VCS_REF="$commit"\
	--build-arg VERSION="$version"\
	--output "$docker_output"\
	--platform "$docker_platforms"\
	$docker_version_tag\
	$docker_channel_tag\
	-f ./docker/Dockerfile\
	.
07070100000087000081A4000000000000000000000001679A649F00000CEE000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/scripts/make/build-release.sh#!/bin/sh
verbose="${VERBOSE:-0}"
readonly verbose

# At verbosity levels above 2, print the environment and show every command.
# At level 2, only show the commands.  Level 1 only enables the messages of
# function log below.
if [ "$verbose" -gt '2' ]
then
	env
	set -x
elif [ "$verbose" -gt '1' ]
then
	set -x
fi

# Exit on errors (-e), disable filename expansion (-f), and treat unset
# variables as errors (-u).
set -e -f -u
# log prints its first argument to stderr if the verbosity level is greater
# than 0.  Otherwise, it prints nothing.
log() {
	if [ "$verbose" -gt '0' ]
	then
		# Quote the message to prevent word splitting.
		echo "$1" 1>&2
	fi
}
log 'starting to build dnsproxy release'

version="${VERSION:-}"
readonly version

log "version '$version'"

dist="${DIST_DIR:-build}"
readonly dist

out="${OUT:-dnsproxy}"
readonly out

log "checking tools"

# Check all tools used to make the archives in function build, including gzip,
# which compresses the tarballs.  Report a missing tool regardless of the
# verbosity level.  Function log is silent by default, so the script would
# otherwise exit without any explanation.
for tool in gzip tar zip
do
	if ! command -v "$tool" > /dev/null
	then
		echo "tool '$tool' not found" 1>&2

		exit 1
	fi
done

# Data section.  Arrange data into space-separated tables for read -r to read.
# Use 0 for missing values.

# os arch arm mips
platforms="\
darwin amd64 0 0
darwin arm64 0 0
freebsd 386 0 0
freebsd amd64 0 0
freebsd arm 5 0
freebsd arm 6 0
freebsd arm 7 0
freebsd arm64 0 0
linux 386 0 0
linux amd64 0 0
linux arm 5 0
linux arm 6 0
linux arm 7 0
linux arm64 0 0
linux mips 0 softfloat
linux mips64 0 softfloat
linux mips64le 0 softfloat
linux mipsle 0 softfloat
linux ppc64le 0 0
openbsd amd64 0 0
openbsd arm64 0 0
windows 386 0 0
windows amd64 0 0
windows arm64 0 0"
readonly platforms
# build builds the binary for a single platform and packs it into an archive.
# The arguments are the build name, the OS, the architecture, the ARM version,
# and the MIPS float mode, with 0 standing for an unset value.
build() {
	# Get the arguments.  Here and below, use the "build_" prefix for all
	# variables local to function build.
	build_dir="${dist}/${1}"\
		build_name="$1"\
		build_os="$2"\
		build_arch="$3"\
		build_arm="$4"\
		build_mips="$5"\
		;

	# Use the ".exe" filename extension if we build a Windows release.
	if [ "$build_os" = 'windows' ]
	then
		build_output="./${build_dir}/${out}.exe"
	else
		build_output="./${build_dir}/${out}"
	fi

	mkdir -p "./${build_dir}"

	# Build the binary.
	#
	# Set GOARM and GOMIPS to an empty string if $build_arm and $build_mips
	# are zero by removing the zero as if it's a prefix.
	#
	# Use the function's own $build_os rather than the caller's loop variable
	# $os, so that the function only depends on its arguments.
	env\
		GOARCH="$build_arch"\
		GOARM="${build_arm#0}"\
		GOMIPS="${build_mips#0}"\
		GOOS="$build_os"\
		VERBOSE="$(( verbose - 1 ))"\
		VERSION="$version"\
		OUT="$build_output"\
		sh ./scripts/make/go-build.sh\
		;

	log "$build_output"

	# Prepare the build directory for archiving.
	cp ./LICENSE ./README.md "$build_dir"

	# Make archives.  Windows prefers ZIP archives; the rest, gzipped tarballs.
	case "$build_os"
	in
	('windows')
		build_archive="./${dist}/${out}-${build_name}-${version}.zip"
		# TODO(a.garipov): Find an option similar to the -C option of tar for
		# zip.
		( cd "${dist}" && zip -9 -q -r "../${build_archive}" "./${build_name}" )
		;;
	(*)
		build_archive="./${dist}/${out}-${build_name}-${version}.tar.gz"
		tar -C "./${dist}" -c -f - "./${build_name}" | gzip -9 - > "$build_archive"
		;;
	esac

	log "$build_archive"
}
log "starting builds"

# Go over all platforms defined in the space-separated table above, tweak the
# values where necessary, and feed to build.
echo "$platforms" | while read -r os arch arm mips
do
	# ARM builds also carry the ARM version in their name, e.g. "linux-arm7".
	name="${os}-${arch}"
	if [ "$arch" = 'arm' ]
	then
		name="${name}${arm}"
	fi

	build "$name" "$os" "$arch" "$arm" "$mips"
done

log "finished"
07070100000088000081A4000000000000000000000001679A649F00000B62000000000000000000000000000000000000002900000000dnsproxy-0.75.0/scripts/make/go-build.sh#!/bin/sh
# dnsproxy build script
#
# The commentary in this file is written with the assumption that the reader
# only has superficial knowledge of the POSIX shell language and alike.
# Experienced readers may find it overly verbose.

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1

# The default verbosity level is 0.  Show every command that is run and every
# package that is processed if the caller requested verbosity level greater than
# 0.  Also show subcommands if the requested verbosity level is greater than 1.
# Otherwise, do nothing.
#
# $v_flags and $x_flags are passed to the go build command at the end of this
# script.
verbose="${VERBOSE:-0}"
readonly verbose

if [ "$verbose" -gt '1' ]
then
	env
	set -x
	v_flags='-v=1'
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=0'
else
	set +x
	v_flags='-v=0'
	x_flags='-x=0'
fi
readonly x_flags v_flags

# Exit the script if a pipeline fails (-e), prevent accidental filename
# expansion (-f), and consider undefined variables as errors (-u).
set -e -f -u
# Allow users to override the go command from environment.  For example, to
# build two releases with two different Go versions and test the difference.
go="${GO:-go}"
readonly go

# Set the build parameters unless already set.
branch="${BRANCH:-$( git rev-parse --abbrev-ref HEAD )}"
revision="${REVISION:-$( git rev-parse --short HEAD )}"
version="${VERSION:-0}"
readonly branch revision version

# Set date and time of the latest commit unless already set.
committime="${SOURCE_DATE_EPOCH:-$( git log -1 --pretty=%ct )}"
readonly committime

# Compile them in by setting the variables of the version package with the -X
# linker flags.
version_pkg='github.com/AdguardTeam/dnsproxy/internal/version'
ldflags="-s -w"
ldflags="${ldflags} -X ${version_pkg}.branch=${branch}"
ldflags="${ldflags} -X ${version_pkg}.committime=${committime}"
ldflags="${ldflags} -X ${version_pkg}.revision=${revision}"
ldflags="${ldflags} -X ${version_pkg}.version=${version}"
readonly ldflags version_pkg

# Allow users to limit the build's parallelism.
parallelism="${PARALLELISM:-}"
readonly parallelism

# Use GOFLAGS for -p, because -p=0 simply disables the build instead of leaving
# the default value.
if [ "${parallelism}" != '' ]
then
	GOFLAGS="${GOFLAGS:-} -p=${parallelism}"
fi
readonly GOFLAGS
export GOFLAGS

# Allow users to specify a different output name.
out="${OUT:-dnsproxy}"
readonly out

o_flags="-o=${out}"
readonly o_flags

# Allow users to enable the race detector.  Unfortunately, that means that cgo
# must be enabled.
if [ "${RACE:-0}" -eq '0' ]
then
	CGO_ENABLED='0'
	race_flags='--race=0'
else
	CGO_ENABLED='1'
	race_flags='--race=1'
fi
readonly CGO_ENABLED race_flags
export CGO_ENABLED

GO111MODULE='on'
export GO111MODULE

# Print the Go environment at verbosity levels above 0.
if [ "$verbose" -gt '0' ]
then
	"$go" env
fi

"$go" build\
	--ldflags="$ldflags"\
	"$race_flags"\
	--trimpath\
	"$o_flags"\
	"$v_flags"\
	"$x_flags"
07070100000089000081A4000000000000000000000001679A649F000001D8000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/go-deps.sh#!/bin/sh
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1
verbose="${VERBOSE:-0}"
readonly verbose
if [ "$verbose" -gt '1' ]
then
env
set -x
x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
set -x
x_flags='-x=0'
else
set +x
x_flags='-x=0'
fi
readonly x_flags
set -e -f -u
go="${GO:-go}"
readonly go
"$go" mod download "$x_flags"
0707010000008A000081A4000000000000000000000001679A649F0000112C000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/go-lint.sh#!/bin/sh
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 7
verbose="${VERBOSE:-0}"
readonly verbose
if [ "$verbose" -gt '0' ]
then
set -x
fi
# Set $EXIT_ON_ERROR to zero to see all errors.
if [ "${EXIT_ON_ERROR:-1}" -eq '0' ]
then
set +e
else
set -e
fi
set -f -u
# Source the common helpers, including not_found and run_linter.
. ./scripts/make/helper.sh
# Simple analyzers
# blocklist_imports is a simple check against unwanted packages. The following
# packages are banned:
#
# * Package errors is replaced by our own package in the
# github.com/AdguardTeam/golibs module.
#
# * Package io/ioutil is soft-deprecated.
#
# * Package log and github.com/AdguardTeam/golibs/log are replaced by
# stdlib's new package log/slog and AdGuard's new utilities package
# github.com/AdguardTeam/golibs/logutil/slogutil.
#
# * Package reflect is often an overkill, and for deep comparisons there are
# much better functions in module github.com/google/go-cmp. Which is
# already our indirect dependency and which may or may not enter the stdlib
# at some point.
#
# See https://github.com/golang/go/issues/45200.
#
# * Package sort is replaced by package slices.
#
# * Package unsafe is… unsafe.
#
# * Package golang.org/x/exp/slices has been moved into stdlib.
#
# * Package golang.org/x/net/context has been moved into stdlib.
#
# Currently, the only standard exception are files generated from protobuf
# schemas, which use package reflect. If your project needs more exceptions,
# add and document them.
#
# TODO(a.garipov): Add deprecated package golang.org/x/exp/maps once all
# projects switch to Go 1.23.
blocklist_imports() {
git grep\
-e '[[:space:]]"errors"$'\
-e '[[:space:]]"github.com/AdguardTeam/golibs/log"$'\
-e '[[:space:]]"golang.org/x/exp/slices"$'\
-e '[[:space:]]"golang.org/x/net/context"$'\
-e '[[:space:]]"io/ioutil"$'\
-e '[[:space:]]"log"$'\
-e '[[:space:]]"reflect"$'\
-e '[[:space:]]"sort"$'\
-e '[[:space:]]"unsafe"$'\
-n\
-- '*.go'\
':!*.pb.go'\
| sed -e 's/^\([^[:space:]]\+\)\(.*\)$/\1 blocked import:\2/'\
|| exit 0
}
# method_const is a simple check against the usage of some raw strings and
# numbers where one should use named constants.
method_const() {
git grep -F\
-e '"DELETE"'\
-e '"GET"'\
-e '"PATCH"'\
-e '"POST"'\
-e '"PUT"'\
-n\
-- '*.go'\
| sed -e 's/^\([^[:space:]]\+\)\(.*\)$/\1 http method literal:\2/'\
|| exit 0
}
# underscores is a simple check against Go filenames with underscores. Add new
# build tags and OS as you go. The main goal of this check is to discourage the
# use of filenames like client_manager.go.
underscores() {
underscore_files="$(
git ls-files '*_*.go'\
| grep -F\
-e '_darwin.go'\
-e '_generate.go'\
-e '_linux.go'\
-e '_others.go'\
-e '_plan9.go'\
-e '_test.go'\
-e '_unix.go'\
-e '_windows.go'\
-e '_dnscrypt.go'\
-e '_https.go'\
-e '_quic.go'\
-e '_tcp.go'\
-e '_udp.go'\
-v\
| sed -e 's/./\t\0/'
)"
readonly underscore_files
if [ "$underscore_files" != '' ]
then
echo 'found file names with underscores:'
echo "$underscore_files"
fi
}
# TODO(a.garipov): Add an analyzer to look for `fallthrough`, `goto`, and `new`?
# Checks
run_linter -e blocklist_imports
run_linter -e method_const
run_linter -e underscores
run_linter -e gofumpt --extra -e -l .
# TODO(a.garipov): golint is deprecated, find a suitable replacement.
run_linter "${GO:-go}" vet ./...
run_linter govulncheck ./...
run_linter gocyclo --over 10 .
run_linter gocognit --over 10 .
run_linter ineffassign ./...
run_linter unparam ./...
git ls-files -- 'Makefile' '*.conf' '*.go' '*.mod' '*.sh' '*.yaml' '*.yml'\
| xargs misspell --error\
| sed -e 's/^/misspell: /'
run_linter looppointer ./...
run_linter nilness ./...
run_linter fieldalignment ./...
run_linter -e shadow --strict ./...
# TODO(a.garipov): Re-enable G115.
run_linter gosec --exclude G115 --quiet ./...
run_linter errcheck ./...
staticcheck_matrix='
darwin: GOOS=darwin
freebsd: GOOS=freebsd
linux: GOOS=linux
openbsd: GOOS=openbsd
windows: GOOS=windows
'
readonly staticcheck_matrix
echo "$staticcheck_matrix" | run_linter staticcheck --matrix ./...
0707010000008B000081A4000000000000000000000001679A649F000003E8000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/go-test.sh#!/bin/sh
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1
verbose="${VERBOSE:-0}"
readonly verbose
# Verbosity levels:
# 0 = Don't print anything except for errors.
# 1 = Print commands, but not nested commands.
# 2 = Print everything.
if [ "$verbose" -gt '1' ]
then
set -x
v_flags='-v=1'
x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
set -x
v_flags='-v=1'
x_flags='-x=0'
else
set +x
v_flags='-v=0'
x_flags='-x=0'
fi
readonly v_flags x_flags
set -e -f -u
if [ "${RACE:-1}" -eq '0' ]
then
race_flags='--race=0'
else
race_flags='--race=1'
fi
readonly race_flags
go="${GO:-go}"
count_flags='--count=1'
shuffle_flags='--shuffle=on'
timeout_flags="${TIMEOUT_FLAGS:---timeout=2m}"
readonly go count_flags shuffle_flags timeout_flags
"$go" test\
"$count_flags"\
"$race_flags"\
"$shuffle_flags"\
"$timeout_flags"\
"$v_flags"\
"$x_flags"\
./...
0707010000008C000081A4000000000000000000000001679A649F00000766000000000000000000000000000000000000002900000000dnsproxy-0.75.0/scripts/make/go-tools.sh#!/bin/sh
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 3
verbose="${VERBOSE:-0}"
readonly verbose
if [ "$verbose" -gt '1' ]
then
set -x
v_flags='-v=1'
x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
set -x
v_flags='-v=1'
x_flags='-x=0'
else
set +x
v_flags='-v=0'
x_flags='-x=0'
fi
readonly v_flags x_flags
set -e -f -u
go="${GO:-go}"
readonly go
# Remove only the actual binaries in the bin/ directory, as developers may add
# their own scripts there. Most commonly, a script named “go” for tools that
# call the go binary and need a particular version.
rm -f\
bin/errcheck\
bin/fieldalignment\
bin/gocognit\
bin/gocyclo\
bin/gofumpt\
bin/gosec\
bin/govulncheck\
bin/ineffassign\
bin/looppointer\
bin/misspell\
bin/nilness\
bin/shadow\
bin/staticcheck\
bin/unparam\
;
# Reset GOARCH and GOOS to make sure we install the tools for the native
# architecture even when we're cross-compiling the main binary, and also to
# prevent the "cannot install cross-compiled binaries when GOBIN is set" error.
env\
GOARCH=""\
GOBIN="${PWD}/bin"\
GOOS=""\
GOWORK='off'\
"$go" install\
--modfile=./internal/tools/go.mod\
"$v_flags"\
"$x_flags"\
github.com/fzipp/gocyclo/cmd/gocyclo\
github.com/golangci/misspell/cmd/misspell\
github.com/gordonklaus/ineffassign\
github.com/kisielk/errcheck\
github.com/kyoh86/looppointer/cmd/looppointer\
github.com/securego/gosec/v2/cmd/gosec\
github.com/uudashr/gocognit/cmd/gocognit\
golang.org/x/tools/go/analysis/passes/fieldalignment/cmd/fieldalignment\
golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness\
golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow\
golang.org/x/vuln/cmd/govulncheck\
honnef.co/go/tools/cmd/staticcheck\
mvdan.cc/gofumpt\
mvdan.cc/unparam\
;
0707010000008D000081A4000000000000000000000001679A649F000001F7000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/scripts/make/go-upd-tools.sh#!/bin/sh
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 2
verbose="${VERBOSE:-0}"
readonly verbose
if [ "$verbose" -gt '1' ]
then
env
set -x
x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
set -x
x_flags='-x=0'
else
set +x
x_flags='-x=0'
fi
readonly x_flags
set -e -f -u
go="${GO:-go}"
readonly go
cd ./internal/tools/
"$go" get -u "$x_flags"
"$go" mod tidy
0707010000008E000081A4000000000000000000000001679A649F00000651000000000000000000000000000000000000002700000000dnsproxy-0.75.0/scripts/make/helper.sh#!/bin/sh
# Common script helpers
#
# This file contains common script helpers. It should be sourced in scripts
# right after the initial environment processing.
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a remarkable change is made to this script.
#
# AdGuard-Project-Version: 3
# Deferred helpers
not_found_msg='
looks like a binary not found error.
make sure you have installed the linter binaries using:
$ make go-tools
'
readonly not_found_msg
not_found() {
if [ "$?" -eq '127' ]
then
# Code 127 is the exit status a shell uses when a command or a file is
# not found, according to the Bash Hackers wiki.
#
# See https://wiki.bash-hackers.org/dict/terms/exit_status.
echo "$not_found_msg" 1>&2
fi
}
trap not_found EXIT
# Helpers
# run_linter runs the given linter with two additions:
#
# 1. If the first argument is "-e", run_linter exits with a nonzero exit code
# if there is anything in the command's combined output.
#
# 2. In any case, run_linter adds the program's name to its combined output.
run_linter() (
set +e
if [ "${VERBOSE:-0}" -lt '2' ]
then
set +x
fi
cmd="${1:?run_linter: provide a command}"
shift
exit_on_output='0'
if [ "$cmd" = '-e' ]
then
exit_on_output='1'
cmd="${1:?run_linter: provide a command}"
shift
fi
readonly cmd
output="$( "$cmd" "$@" )"
exitcode="$?"
readonly output
if [ "$output" != '' ]
then
echo "$output" | sed -e "s/^/${cmd}: /"
if [ "$exitcode" -eq '0' ] && [ "$exit_on_output" -eq '1' ]
then
exitcode='1'
fi
fi
return "$exitcode"
)
0707010000008F000081A4000000000000000000000001679A649F0000018B000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/md-lint.sh#!/bin/sh
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a remarkable change is made to this script.
#
# AdGuard-Project-Version: 2
verbose="${VERBOSE:-0}"
readonly verbose
set -e -f -u
if [ "$verbose" -gt '0' ]
then
set -x
fi
# NOTE: Adjust for your project.
# markdownlint\
# ./README.md\
# ;
# TODO(e.burkov): Lint README.md.
07070100000090000081A4000000000000000000000001679A649F000001CE000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/sh-lint.sh#!/bin/sh
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a remarkable change is made to this script.
#
# AdGuard-Project-Version: 2
verbose="${VERBOSE:-0}"
readonly verbose
# Don't use -f, because we use globs in this script.
set -e -u
if [ "$verbose" -gt '0' ]
then
set -x
fi
# NOTE: Adjust for your project.
shellcheck -e 'SC2250' -f 'gcc' -o 'all' -x --\
./scripts/hooks/*\
./scripts/make/*\
;
07070100000091000081A4000000000000000000000001679A649F000006A9000000000000000000000000000000000000002900000000dnsproxy-0.75.0/scripts/make/txt-lint.sh#!/bin/sh
# This comment is used to simplify checking local copies of the script. Bump
# this number every time a remarkable change is made to this script.
#
# AdGuard-Project-Version: 5
verbose="${VERBOSE:-0}"
readonly verbose
if [ "$verbose" -gt '0' ]
then
set -x
fi
# Set $EXIT_ON_ERROR to zero to see all errors.
if [ "${EXIT_ON_ERROR:-1}" -eq '0' ]
then
set +e
else
set -e
fi
# We don't need glob expansions and we want to see errors about unset variables.
set -f -u
# Source the common helpers, including not_found.
. ./scripts/make/helper.sh
# Simple analyzers
# trailing_newlines is a simple check that makes sure that all plain-text files
# have a trailing newlines to make sure that all tools work correctly with them.
trailing_newlines() (
nl="$( printf "\n" )"
readonly nl
# NOTE: Adjust for your project.
git ls-files\
':!*.bmp'\
':!*.jpg'\
':!*.mmdb'\
':!*.png'\
':!*.tar.gz'\
':!*.webp'\
':!*.zip'\
| while read -r f
do
final_byte="$( tail -c -1 "$f" )"
if [ "$final_byte" != "$nl" ]
then
printf '%s: must have a trailing newline\n' "$f"
fi
done
)
# trailing_whitespace is a simple check that makes sure that there are no
# trailing whitespace in plain-text files.
trailing_whitespace() {
# NOTE: Adjust for your project.
git ls-files\
':!*.bmp'\
':!*.jpg'\
':!*.mmdb'\
':!*.png'\
':!*.tar.gz'\
':!*.webp'\
':!*.zip'\
| while read -r f
do
grep -e '[[:space:]]$' -n -- "$f"\
| sed -e "s:^:${f}\::" -e 's/ \+$/>>>&<<</'
done
}
run_linter -e trailing_newlines
run_linter -e trailing_whitespace
git ls-files -- '*.conf' '*.md' '*.txt' '*.yaml' '*.yml'\
| xargs misspell --error\
| sed -e 's/^/misspell: /'
07070100000092000081A4000000000000000000000001679A649F0000025B000000000000000000000000000000000000002100000000dnsproxy-0.75.0/staticcheck.conf# This comment is used to simplify checking local copies of the staticcheck
# configuration. Bump this number every time a significant change is made to
# this file.
#
# AdGuard-Project-Version: 1
checks = ["all"]
initialisms = [
# See https://github.com/dominikh/go-tools/blob/master/config/config.go.
#
# Do not add "PTR" since we use "Ptr" as a suffix.
"inherit"
, "ASN"
, "DHCP"
, "DNSSEC"
# E.g. SentryDSN.
, "DSN"
, "ECS"
, "EDNS"
, "MX"
, "QUIC"
, "RA"
, "RRSIG"
, "RTT"
, "SDNS"
, "SLAAC"
, "SOA"
, "SVCB"
, "TLD"
, "WHOIS"
]
dot_import_whitelist = []
http_status_code_whitelist = []
07070100000093000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001900000000dnsproxy-0.75.0/upstream07070100000094000081A4000000000000000000000001679A649F000010CD000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/dnscrypt.gopackage upstream
import (
"fmt"
"io"
"log/slog"
"net/url"
"os"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/ameshkov/dnscrypt/v2"
"github.com/miekg/dns"
)
// dnsCrypt implements the [Upstream] interface for the DNSCrypt protocol.
type dnsCrypt struct {
// mu protects client and serverInfo.
mu *sync.RWMutex
// client stores the DNSCrypt client properties.
client *dnscrypt.Client
// resolverInfo stores the DNSCrypt server properties.
resolverInfo *dnscrypt.ResolverInfo
// addr is the DNSCrypt server URL.
addr *url.URL
// logger is used for exchange logging. It is never nil.
logger *slog.Logger
// verifyCert is a callback that verifies the resolver's certificate.
verifyCert func(cert *dnscrypt.Cert) (err error)
// timeout is the timeout for the DNS requests.
timeout time.Duration
}
// newDNSCrypt returns a new DNSCrypt Upstream.
func newDNSCrypt(addr *url.URL, opts *Options) (u *dnsCrypt) {
return &dnsCrypt{
mu: &sync.RWMutex{},
addr: addr,
logger: opts.Logger,
verifyCert: opts.VerifyDNSCryptCertificate,
timeout: opts.Timeout,
}
}
// type check
var _ Upstream = (*dnsCrypt)(nil)
// Address implements the [Upstream] interface for *dnsCrypt.
func (p *dnsCrypt) Address() string { return p.addr.String() }
// Exchange implements the [Upstream] interface for *dnsCrypt.
func (p *dnsCrypt) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
resp, err = p.exchangeDNSCrypt(req)
if errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF) {
// If request times out, it is possible that the server configuration
// has been changed. It is safe to assume that the key was rotated, see
// https://dnscrypt.pl/2017/02/26/how-key-rotation-is-automated.
// Re-fetch the server certificate info for new requests to not fail.
_, _, err = p.resetClient()
if err != nil {
return nil, err
}
return p.exchangeDNSCrypt(req)
}
return resp, err
}
// Close implements the [Upstream] interface for *dnsCrypt.
func (p *dnsCrypt) Close() (err error) {
return nil
}
// exchangeDNSCrypt attempts to send the DNS query and returns the response.
func (p *dnsCrypt) exchangeDNSCrypt(req *dns.Msg) (resp *dns.Msg, err error) {
var client *dnscrypt.Client
var resolverInfo *dnscrypt.ResolverInfo
func() {
p.mu.RLock()
defer p.mu.RUnlock()
client, resolverInfo = p.client, p.resolverInfo
}()
// Check the client and server info are set and the certificate is not
// expired, since any of these cases require a client reset.
//
// TODO(ameshkov): Consider using [time.Time] for [dnscrypt.Cert.NotAfter].
switch {
case
client == nil,
resolverInfo == nil,
resolverInfo.ResolverCert.NotAfter < uint32(time.Now().Unix()):
client, resolverInfo, err = p.resetClient()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
default:
// Go on.
}
resp, err = client.Exchange(req, resolverInfo)
if resp != nil && resp.Truncated {
q := &req.Question[0]
p.logger.Debug(
"dnscrypt received truncated, falling back to tcp",
"addr", p.addr,
"question", q,
)
tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: networkTCP}
resp, err = tcpClient.Exchange(req, resolverInfo)
}
if err == nil && resp != nil && resp.Id != req.Id {
err = dns.ErrId
}
return resp, err
}
// resetClient renews the DNSCrypt client and server properties and also sets
// those to nil on fail.
func (p *dnsCrypt) resetClient() (client *dnscrypt.Client, ri *dnscrypt.ResolverInfo, err error) {
addr := p.Address()
defer func() {
p.mu.Lock()
defer p.mu.Unlock()
p.client, p.resolverInfo = client, ri
}()
// Use UDP for DNSCrypt upstreams by default.
client = &dnscrypt.Client{Timeout: p.timeout, Net: networkUDP}
ri, err = client.Dial(addr)
if err != nil {
// Trigger client and server info renewal on the next request.
client, ri = nil, nil
err = fmt.Errorf("fetching certificate info from %s: %w", addr, err)
} else if p.verifyCert != nil {
err = p.verifyCert(ri.ResolverCert)
if err != nil {
// Trigger client and server info renewal on the next request.
client, ri = nil, nil
err = fmt.Errorf("verifying certificate info from %s: %w", addr, err)
}
}
return client, ri, err
}
07070100000095000081A4000000000000000000000001679A649F00001954000000000000000000000000000000000000003300000000dnsproxy-0.75.0/upstream/dnscrypt_internal_test.gopackage upstream
import (
"context"
"net"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/ameshkov/dnscrypt/v2"
"github.com/ameshkov/dnsstamps"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// dnsCryptHandlerFunc is a function-based implementation of the
// [dnscrypt.Handler] interface.
type dnsCryptHandlerFunc func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error)
// ServeDNS implements the [dnscrypt.Handler] interface for DNSCryptHandlerFunc.
func (f dnsCryptHandlerFunc) ServeDNS(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
return f(w, r)
}
// startTestDNSCryptServer starts a test DNSCrypt server with the specified
// resolver config and handler.
func startTestDNSCryptServer(
t testing.TB,
rc dnscrypt.ResolverConfig,
h dnscrypt.Handler,
) (stamp dnsstamps.ServerStamp) {
t.Helper()
cert, err := rc.CreateCert()
require.NoError(t, err)
s := &dnscrypt.Server{
ProviderName: rc.ProviderName,
ResolverCert: cert,
Handler: h,
}
testutil.CleanupAndRequireSuccess(t, func() (err error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return s.Shutdown(ctx)
})
localhost := netutil.IPv4Localhost().AsSlice()
// Prepare TCP listener.
tcpAddr := &net.TCPAddr{IP: localhost, Port: 0}
tcpConn, err := net.ListenTCP("tcp", tcpAddr)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, tcpConn.Close)
// Prepare UDP listener on the same port.
port := testutil.RequireTypeAssert[*net.TCPAddr](t, tcpConn.Addr()).Port
udpAddr := &net.UDPAddr{IP: localhost, Port: port}
udpConn, err := net.ListenUDP("udp", udpAddr)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, udpConn.Close)
// Start the server.
go func() {
udpErr := s.ServeUDP(udpConn)
require.ErrorIs(testutil.PanicT{}, udpErr, net.ErrClosed)
}()
go func() {
tcpErr := s.ServeTCP(tcpConn)
require.NoError(testutil.PanicT{}, tcpErr)
}()
stamp, err = rc.CreateStamp(udpConn.LocalAddr().String())
require.NoError(t, err)
_, err = net.Dial("tcp", udpAddr.String())
require.NoError(t, err)
return stamp
}
func TestUpstreamDNSCrypt(t *testing.T) {
t.Parallel()
// AdGuard DNS (DNSCrypt)
address := "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"
u, err := AddressToUpstream(address, &Options{
Logger: slogutil.NewDiscardLogger(),
Timeout: dialTimeout,
})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
// Test that it responds properly
for range 10 {
checkUpstream(t, u, address)
}
}
func TestDNSCrypt_Exchange_truncated(t *testing.T) {
// Prepare the test DNSCrypt server config
rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
require.NoError(t, err)
var udpNum, tcpNum atomic.Uint32
h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
if w.RemoteAddr().Network() == networkUDP {
udpNum.Add(1)
} else {
tcpNum.Add(1)
}
res := (&dns.Msg{}).SetReply(r)
answer := &dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeTXT,
Ttl: 300,
Class: dns.ClassINET,
},
}
res.Answer = append(res.Answer, answer)
veryLongString := strings.Repeat("VERY LONG STRING", 7)
for range 50 {
answer.Txt = append(answer.Txt, veryLongString)
}
return w.WriteMsg(res)
})
srvStamp := startTestDNSCryptServer(t, rc, h)
u, err := AddressToUpstream(srvStamp.String(), &Options{
Logger: slogutil.NewDiscardLogger(),
Timeout: timeout,
})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)
// Check that response is not truncated (even though it's huge).
res, err := u.Exchange(req)
require.NoError(t, err)
assert.False(t, res.Truncated)
assert.Equal(t, 1, int(udpNum.Load()))
assert.Equal(t, 1, int(tcpNum.Load()))
}
func TestDNSCrypt_Exchange_deadline(t *testing.T) {
t.Parallel()
// Prepare the test DNSCrypt server config
rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
require.NoError(t, err)
h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
return nil
})
srvStamp := startTestDNSCryptServer(t, rc, h)
// Use a shorter timeout to speed up the test.
u, err := AddressToUpstream(srvStamp.String(), &Options{
Logger: slogutil.NewDiscardLogger(),
Timeout: 100 * time.Millisecond,
})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)
res, err := u.Exchange(req)
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
assert.Nil(t, res)
}
func TestDNSCrypt_Exchange_dialFail(t *testing.T) {
// Prepare the test DNSCrypt server config
rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
require.NoError(t, err)
h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
return nil
})
req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)
var u Upstream
require.True(t, t.Run("run_and_shutdown", func(t *testing.T) {
srvStamp := startTestDNSCryptServer(t, rc, h)
// Use a shorter timeout to speed up the test.
u, err = AddressToUpstream(srvStamp.String(), &Options{
Logger: slogutil.NewDiscardLogger(),
Timeout: 100 * time.Millisecond,
})
require.NoError(t, err)
}))
require.True(t, t.Run("dial_fail", func(t *testing.T) {
testutil.CleanupAndRequireSuccess(t, u.Close)
var res *dns.Msg
res, err = u.Exchange(req)
require.Error(t, err)
assert.Nil(t, res)
}))
t.Run("restart", func(t *testing.T) {
const validationErr errors.Error = "bad cert"
srvStamp := startTestDNSCryptServer(t, rc, h)
// Use a shorter timeout to speed up the test.
u, err = AddressToUpstream(srvStamp.String(), &Options{
Logger: slogutil.NewDiscardLogger(),
Timeout: 100 * time.Millisecond,
VerifyDNSCryptCertificate: func(cert *dnscrypt.Cert) (err error) {
return validationErr
},
})
require.NoError(t, err)
var res *dns.Msg
res, err = u.Exchange(req)
require.ErrorIs(t, err, validationErr)
assert.Nil(t, res)
})
}
07070100000096000081A4000000000000000000000001679A649F0000544E000000000000000000000000000000000000002000000000dnsproxy-0.75.0/upstream/doh.gopackage upstream
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"golang.org/x/net/http2"
)
// Values to configure HTTP and HTTP/2 transport.
const (
// transportDefaultReadIdleTimeout is the default timeout for pinging
// idle connections in HTTP/2 transport.
transportDefaultReadIdleTimeout = 30 * time.Second
// transportDefaultIdleConnTimeout is the default timeout for idle
// connections in HTTP transport.
transportDefaultIdleConnTimeout = 5 * time.Minute
// dohMaxConnsPerHost controls the maximum number of connections for
// each host. Note, that setting it to 1 may cause issues with Go's http
// implementation, see https://github.com/AdguardTeam/dnsproxy/issues/278.
dohMaxConnsPerHost = 2
// dohMaxIdleConns controls the maximum number of connections being idle
// at the same time.
dohMaxIdleConns = 2
)
// dnsOverHTTPS is a struct that implements the Upstream interface for the
// DNS-over-HTTPS protocol.
type dnsOverHTTPS struct {
// getDialer either returns an initialized dial handler or creates a new
// one.
getDialer DialerInitializer
// addr is the DNS-over-HTTPS server URL.
addr *url.URL
// tlsConf is the configuration of TLS.
tlsConf *tls.Config
// The Client's Transport typically has internal state (cached TCP
// connections), so Clients should be reused instead of created as needed.
// Clients are safe for concurrent use by multiple goroutines.
client *http.Client
// clientMu protects client.
clientMu *sync.Mutex
// logger is used for exchange logging. It is never nil.
logger *slog.Logger
// quicConf is the QUIC configuration that is used if HTTP/3 is enabled
// for this upstream.
quicConf *quic.Config
// quicConfMu protects quicConf.
quicConfMu *sync.Mutex
// transportH2 is an HTTP/2 transport if any.
transportH2 *http2.Transport
// addrRedacted is the redacted string representation of addr. It is saved
// separately to reduce allocations during logging and error reporting.
addrRedacted string
// timeout is used in HTTP client and for H3 probes.
timeout time.Duration
}
// newDoH returns the DNS-over-HTTPS Upstream.
func newDoH(addr *url.URL, opts *Options) (u Upstream, err error) {
addPort(addr, defaultPortDoH)
var httpVersions []HTTPVersion
if addr.Scheme == "h3" {
addr.Scheme = "https"
httpVersions = []HTTPVersion{HTTPVersion3}
} else if httpVersions = opts.HTTPVersions; len(opts.HTTPVersions) == 0 {
httpVersions = DefaultHTTPVersions
}
ups := &dnsOverHTTPS{
getDialer: newDialerInitializer(addr, opts),
addr: addr,
quicConf: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
Tracer: opts.QUICTracer,
},
quicConfMu: &sync.Mutex{},
tlsConf: &tls.Config{
ServerName: addr.Hostname(),
RootCAs: opts.RootCAs,
CipherSuites: opts.CipherSuites,
// Use the default capacity for the LRU cache. It may be useful to
// store several caches since the user may be routed to different
// servers in case there's load balancing on the server-side.
ClientSessionCache: tls.NewLRUClientSessionCache(0),
MinVersion: tls.VersionTLS12,
// #nosec G402 -- TLS certificate verification could be disabled by
// configuration.
InsecureSkipVerify: opts.InsecureSkipVerify,
VerifyPeerCertificate: opts.VerifyServerCertificate,
VerifyConnection: opts.VerifyConnection,
},
clientMu: &sync.Mutex{},
logger: opts.Logger,
addrRedacted: addr.Redacted(),
timeout: opts.Timeout,
}
for _, v := range httpVersions {
ups.tlsConf.NextProtos = append(ups.tlsConf.NextProtos, string(v))
}
runtime.SetFinalizer(ups, (*dnsOverHTTPS).Close)
return ups, nil
}
// type check
var _ Upstream = (*dnsOverHTTPS)(nil)
// Address implements the [Upstream] interface for *dnsOverHTTPS. The address
// is redacted: if the original URL of this upstream contains a userinfo with a
// password, the password is replaced with "xxxxx".
func (p *dnsOverHTTPS) Address() string { return p.addrRedacted }
// Exchange implements the [Upstream] interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
// In order to maximize HTTP cache friendliness, DoH clients using media
// formats that include the ID field from the DNS message header, such as
// "application/dns-message", SHOULD use a DNS ID of 0 in every DNS request.
//
// See https://www.rfc-editor.org/rfc/rfc8484.html.
id := req.Id
req.Id = 0
defer func() {
// Restore the original ID to not break compatibility with proxies.
req.Id = id
if resp != nil {
resp.Id = id
}
}()
// Check if there was already an active client before sending the request.
// We'll only attempt to re-connect if there was one.
client, isCached, err := p.getClient()
if err != nil {
return nil, fmt.Errorf("failed to init http client: %w", err)
}
// Make the first attempt to send the DNS query.
resp, err = p.exchangeHTTPS(client, req)
// Make up to 2 attempts to re-create the HTTP client and send the request
// again. There are several cases (mostly, with QUIC) where this workaround
// is necessary to make HTTP client usable. We need to make 2 attempts in
// the case when the connection was closed (due to inactivity for example)
// AND the server refuses to open a 0-RTT connection.
for i := 0; isCached && p.shouldRetry(err) && i < 2; i++ {
client, err = p.resetClient(err)
if err != nil {
return nil, fmt.Errorf("failed to reset http client: %w", err)
}
resp, err = p.exchangeHTTPS(client, req)
}
if err != nil {
// If the request failed anyway, make sure we don't use this client.
_, resErr := p.resetClient(err)
return nil, errors.WithDeferred(err, resErr)
}
return resp, err
}
// Close implements the Upstream interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Close() (err error) {
p.clientMu.Lock()
defer p.clientMu.Unlock()
runtime.SetFinalizer(p, nil)
if p.client != nil {
err = p.closeClient(p.client)
}
return err
}
// closeClient cleans up resources used by client if necessary. Note that this
// should be done for HTTP/3, as it can lead to resource leaks due to keep-alive
// connections, and for HTTP/2 due to idle connections.
func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
if isHTTP3(client) {
return client.Transport.(io.Closer).Close()
} else if p.transportH2 != nil {
p.transportH2.CloseIdleConnections()
}
return nil
}
// exchangeHTTPS wraps exchangeHTTPSClient with begin and finish logging.  The
// network reported in the log is UDP for HTTP/3 clients and TCP otherwise.
func (p *dnsOverHTTPS) exchangeHTTPS(client *http.Client, req *dns.Msg) (resp *dns.Msg, err error) {
	network := networkTCP
	if isHTTP3(client) {
		network = networkUDP
	}

	logBegin(p.logger, p.addrRedacted, network, req)
	defer func() {
		logFinish(p.logger, p.addrRedacted, network, err)
	}()

	resp, err = p.exchangeHTTPSClient(client, req)

	return resp, err
}
// exchangeHTTPSClient sends req to the DoH resolver using client and returns
// the parsed response.
//
// The query is sent as a GET request with the message base64url-encoded into
// the "dns" query parameter.  HTTP/3 clients use [http3.MethodGet0RTT], so
// that the request can be sent as 0-RTT data.
//
// At most [dns.MaxMsgSize] bytes of the response body are accepted, since no
// DNS message can be larger.  Bigger bodies are rejected instead of being
// buffered in full.
//
// If the ID of the response doesn't match req.Id, the response is returned
// together with [dns.ErrId].
func (p *dnsOverHTTPS) exchangeHTTPSClient(
	client *http.Client,
	req *dns.Msg,
) (resp *dns.Msg, err error) {
	buf, err := req.Pack()
	if err != nil {
		return nil, fmt.Errorf("packing message: %w", err)
	}

	// GET requests appear to be more memory-efficient with the Go
	// implementation of HTTP/2.
	method := http.MethodGet
	if isHTTP3(client) {
		// If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT.
		method = http3.MethodGet0RTT
	}

	q := url.Values{
		"dns": []string{base64.RawURLEncoding.EncodeToString(buf)},
	}

	u := url.URL{
		Scheme:   p.addr.Scheme,
		User:     p.addr.User,
		Host:     p.addr.Host,
		Path:     p.addr.Path,
		RawQuery: q.Encode(),
	}

	httpReq, err := http.NewRequest(method, u.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("creating http request to %s: %w", p.addrRedacted, err)
	}

	// Prevent the client from sending User-Agent header, see
	// https://github.com/AdguardTeam/dnsproxy/issues/211.
	httpReq.Header.Set(httphdr.UserAgent, "")
	httpReq.Header.Set(httphdr.Accept, "application/dns-message")

	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("requesting %s: %w", p.addrRedacted, err)
	}
	defer slogutil.CloseAndLog(httpReq.Context(), p.logger, httpResp.Body, slog.LevelDebug)

	// Read one byte more than the largest possible DNS message, so that an
	// oversized body is detected without buffering all of it.  Otherwise a
	// misbehaving server could make us allocate an arbitrary amount of
	// memory.
	body, err := io.ReadAll(io.LimitReader(httpResp.Body, dns.MaxMsgSize+1))
	if err != nil {
		return nil, fmt.Errorf("reading %s: %w", p.addrRedacted, err)
	}

	if httpResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf(
			"expected status %d, got %d from %s",
			http.StatusOK,
			httpResp.StatusCode,
			p.addrRedacted,
		)
	}

	if len(body) > dns.MaxMsgSize {
		return nil, fmt.Errorf(
			"response from %s is larger than %d bytes",
			p.addrRedacted,
			dns.MaxMsgSize,
		)
	}

	resp = &dns.Msg{}
	err = resp.Unpack(body)
	if err != nil {
		return nil, fmt.Errorf(
			"unpacking response from %s: body is %s: %w",
			p.addrRedacted,
			body,
			err,
		)
	}

	if resp.Id != req.Id {
		err = dns.ErrId
	}

	return resp, err
}
// shouldRetry reports whether err justifies re-creating the HTTP client and
// sending the request again.  It is true for network timeouts and for errors
// accepted by isQUICRetryError.  It is false for a nil error and for
// everything else.
func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) {
	if err == nil {
		return false
	}

	// On a timeout, try to forcibly re-create the HTTP client instance.  This
	// is an attempt to fix an issue with the DoH client stalling after a
	// network change.
	//
	// See https://github.com/AdguardTeam/AdGuardHome/issues/3217.
	var netErr net.Error
	isTimeout := errors.As(err, &netErr) && netErr.Timeout()

	return isTimeout || isQUICRetryError(err)
}
// resetClient replaces the *http.Client used by this upstream with a newly
// created one and returns it.
//
// resetErr is the error that caused the reset.  If it is
// [quic.Err0RTTRejected], the QUIC config and its token store are re-created
// as well.  A failure to close the previous client is only logged.
func (p *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) {
	p.clientMu.Lock()
	defer p.clientMu.Unlock()

	if errors.Is(resetErr, quic.Err0RTTRejected) {
		// Reset the TokenStore only if 0-RTT was rejected.
		p.resetQUICConfig()
	}

	if prev := p.client; prev != nil {
		if closeErr := p.closeClient(prev); closeErr != nil {
			p.logger.Warn("failed to close the old http client", slogutil.KeyError, closeErr)
		}
	}

	p.logger.Debug("recreating the http client", slogutil.KeyError, resetErr)

	p.client, err = p.createClient()

	return p.client, err
}
// getQUICConfig returns the current QUIC config while holding quicConfMu.  The
// returned value is shared, so callers must not modify its fields.
func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) {
	p.quicConfMu.Lock()
	c = p.quicConf
	p.quicConfMu.Unlock()

	return c
}
// resetQUICConfig replaces the QUIC config with a copy that has a fresh token
// store.  This ensures that tokens that may no longer be valid aren't used for
// 0-RTT.
//
// The config is cloned rather than modified in place, so a config previously
// returned by getQUICConfig is left untouched.
func (p *dnsOverHTTPS) resetQUICConfig() {
	p.quicConfMu.Lock()
	defer p.quicConfMu.Unlock()

	conf := p.quicConf.Clone()
	conf.TokenStore = newQUICTokenStore()
	p.quicConf = conf
}
// getClient returns the cached HTTP client of this DoH resolver, or lazily
// creates one together with its transport.  isCached reports whether an
// existing client was returned.
//
// If p.timeout is positive and has already elapsed while waiting for the
// lock, no client is created and an error is returned.
func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) {
	lockStart := time.Now()

	p.clientMu.Lock()
	defer p.clientMu.Unlock()

	if cached := p.client; cached != nil {
		return cached, true, nil
	}

	// Waiting for the lock can take longer than the timeout.  This happens
	// quite often on mobile devices.
	if waited := time.Since(lockStart); p.timeout > 0 && waited > p.timeout {
		return nil, false, fmt.Errorf("timeout exceeded: %s", waited)
	}

	p.logger.Debug("creating a new http client")

	c, err = p.createClient()
	p.client = c

	return c, false, err
}
// createClient builds a new *http.Client on top of the transport returned by
// createTransport, stores it in p.client, and returns it.
//
// Depending on the upstream's HTTP versions, creating the transport may
// involve probing a QUIC connection to find out whether HTTP/3 can be used.
// The callers in this file hold p.clientMu while calling it.
func (p *dnsOverHTTPS) createClient() (*http.Client, error) {
	transport, err := p.createTransport()
	if err != nil {
		return nil, fmt.Errorf("initializing http transport: %w", err)
	}

	// TODO(ameshkov): p.timeout may appear zero that will disable the
	// timeout for client, consider using the default.
	p.client = &http.Client{
		Transport: transport,
		Timeout:   p.timeout,
		Jar:       nil,
	}

	return p.client, nil
}
// createTransport initializes an HTTP transport used specifically for this DoH
// resolver.  The transport dials through the handler returned by p.getDialer,
// so requests go to the bootstrapped address.
//
// If HTTP/3 is enabled in the upstream options, an HTTP/3 transport is tried
// first (see createTransportH3).  If that fails, an HTTP/1.1 and HTTP/2
// transport is returned instead, provided the upstream allows those versions.
// If it doesn't, the HTTP/3 error is wrapped into the returned error, so that
// the actual cause (e.g. a failed bootstrap dial) isn't lost.
func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
	dialContext, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("bootstrapping %s: %w", p.addrRedacted, err)
	}

	// First, we attempt to create an HTTP3 transport.  If the probe QUIC
	// connection is established successfully, we'll be using HTTP3 for this
	// upstream.
	tlsConf := p.tlsConf.Clone()
	transportH3, errH3 := p.createTransportH3(tlsConf, dialContext)
	if errH3 == nil {
		p.logger.Debug("using http/3 for this upstream, quic was faster")

		return transportH3, nil
	}

	p.logger.Debug("got error, switching to http/2 for this upstream", slogutil.KeyError, errH3)

	if !p.supportsHTTP() {
		// Keep the HTTP/3 error, since it's the real reason why no transport
		// could be created.
		return nil, fmt.Errorf("HTTP/1.1 and HTTP/2 are not supported by this upstream: %w", errH3)
	}

	transport := &http.Transport{
		TLSClientConfig:    tlsConf,
		DisableCompression: true,
		DialContext:        dialContext,
		IdleConnTimeout:    transportDefaultIdleConnTimeout,
		MaxConnsPerHost:    dohMaxConnsPerHost,
		MaxIdleConns:       dohMaxIdleConns,
		// Since we have a custom DialContext, we need to use this field to make
		// golang http.Client attempt to use HTTP/2.  Otherwise, it would only be
		// used when negotiated on the TLS level.
		ForceAttemptHTTP2: true,
	}

	// Explicitly configure transport to use HTTP/2.
	//
	// See https://github.com/AdguardTeam/dnsproxy/issues/11.
	p.transportH2, err = http2.ConfigureTransports(transport)
	if err != nil {
		return nil, fmt.Errorf("configuring http/2 transport: %w", err)
	}

	// Enable HTTP/2 pings on idle connections.
	p.transportH2.ReadIdleTimeout = transportDefaultReadIdleTimeout

	return transport, nil
}
// http3Transport wraps [*http3.Transport] to tune its behavior.  It prefers
// reusing a cached connection to the host over dialing a new one for every
// request.  It also makes Close wait for in-flight requests, which helps
// mitigate race issues with quic-go.
type http3Transport struct {
	// baseTransport is the wrapped HTTP/3 transport.
	baseTransport *http3.Transport

	// closed is set by Close.  Once it's true, RoundTrip rejects requests.
	closed bool

	// mu guards closed.  RoundTrip holds it for reading for the whole
	// request, so Close can't run while requests are in flight.
	mu sync.RWMutex
}

// type check
var _ http.RoundTripper = (*http3Transport)(nil)

// RoundTrip implements the [http.RoundTripper] interface for *http3Transport.
// It returns [net.ErrClosed] after Close has been called.
func (h *http3Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if h.closed {
		return nil, net.ErrClosed
	}

	// Prefer an already established connection to the target host.
	cachedOnly := http3.RoundTripOpt{OnlyCachedConn: true}
	resp, err = h.baseTransport.RoundTripOpt(req, cachedOnly)
	if !errors.Is(err, http3.ErrNoCachedConn) {
		return resp, err
	}

	// There is no cached connection yet, so let the base transport dial one.
	return h.baseTransport.RoundTrip(req)
}

// type check
var _ io.Closer = (*http3Transport)(nil)

// Close implements the [io.Closer] interface for *http3Transport.  It marks
// the transport as closed and closes the underlying HTTP/3 transport.
func (h *http3Transport) Close() (err error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.closed = true

	return h.baseTransport.Close()
}
// createTransportH3 tries to create an HTTP/3 transport for this upstream.
//
// It fails in two cases, which lets the caller fall back to HTTP/1.1 and
// HTTP/2:
//
//   - HTTP/3 isn't among the upstream's ALPN protocols;
//   - probeH3, which races QUIC against TLS when both are allowed, decides
//     against HTTP/3.
//
// The returned transport always dials the address chosen by probeH3.
func (p *dnsOverHTTPS) createTransportH3(
	tlsConfig *tls.Config,
	dialContext bootstrap.DialHandler,
) (roundTripper http.RoundTripper, err error) {
	if !p.supportsH3() {
		return nil, errors.Error("HTTP3 support is not enabled")
	}

	addr, err := p.probeH3(tlsConfig, dialContext)
	if err != nil {
		return nil, err
	}

	// dialBootstrapped ignores the requested address and always connects to
	// the one that we got from the bootstrapper.
	dialBootstrapped := func(
		ctx context.Context,
		_ string,
		tlsCfg *tls.Config,
		cfg *quic.Config,
	) (quic.EarlyConnection, error) {
		return quic.DialAddrEarly(ctx, addr, tlsCfg, cfg)
	}

	base := &http3.Transport{
		Dial:               dialBootstrapped,
		DisableCompression: true,
		TLSClientConfig:    tlsConfig,
		QUICConfig:         p.getQUICConfig(),
	}

	return &http3Transport{baseTransport: base}, nil
}
// probeH3 checks whether QUIC should be used for this upstream and, if so,
// returns the address to use for QUIC connections.
//
// The address is the remote address of a UDP "connection" opened through
// dialContext.  For HTTP/3-only upstreams it is returned right away.
// Otherwise QUIC and TLS probes are run in parallel:
//
//   - the address is returned if QUIC succeeds first or TLS fails first;
//   - an error is returned if QUIC fails first or TLS succeeds first.
func (p *dnsOverHTTPS) probeH3(
	tlsConfig *tls.Config,
	dialContext bootstrap.DialHandler,
) (addr string, err error) {
	// The bootstrapped dialer uses its own address rather than the one passed
	// here.  This doesn't create an actual connection, but it shows which IP
	// is actually reachable when there are both v4 and v6 addresses.
	rawConn, err := dialContext(context.Background(), "udp", "")
	if err != nil {
		return "", fmt.Errorf("failed to dial: %w", err)
	}

	// The connection itself is never used, only its remote address.
	_ = rawConn.Close()

	udpConn, ok := rawConn.(*net.UDPConn)
	if !ok {
		return "", fmt.Errorf("not a UDP connection to %s", p.addrRedacted)
	}

	addr = udpConn.RemoteAddr().String()

	// There is nothing to choose from if this upstream only supports HTTP/3,
	// so don't spend time on probing.
	if p.supportsH3() && !p.supportsHTTP() {
		return addr, nil
	}

	// Probe connections use their own *tls.Config:
	//
	//   - with an empty session cache, since surprisingly, sharing it with the
	//     probes invalidates the existing cache;
	//   - without the callbacks passed in the bootstrap options, to avoid
	//     side effects.
	//
	// TODO(ameshkov): figure out why the sessions cache invalidates here.
	// TODO(ameshkov): consider exposing, somehow mark that this is a probe.
	probeTLSCfg := tlsConfig.Clone()
	probeTLSCfg.ClientSessionCache = nil
	probeTLSCfg.VerifyPeerCertificate = nil
	probeTLSCfg.VerifyConnection = nil

	// Both channels are buffered, so the losing probe never blocks.
	quicRes := make(chan error, 1)
	tlsRes := make(chan error, 1)
	go p.probeQUIC(addr, probeTLSCfg, quicRes)
	go p.probeTLS(dialContext, probeTLSCfg, tlsRes)

	select {
	case quicErr := <-quicRes:
		if quicErr != nil {
			// QUIC failed first, so HTTP/3 isn't preferred.
			return "", quicErr
		}

		// QUIC was faster.
		return addr, nil
	case tlsErr := <-tlsRes:
		if tlsErr == nil {
			return "", errors.Error("TLS was faster than QUIC, prefer it")
		}

		// TLS failed, so don't wait for QUIC and use it right away.
		p.logger.Debug("probing tls", slogutil.KeyError, tlsErr)

		return addr, nil
	}
}
// probeQUIC tries to establish a QUIC connection to addr.  It sends the result
// to ch: nil on success, or the wrapped dial error otherwise.
//
// The attempt is bounded by p.timeout, or by dialTimeout if p.timeout is zero.
// It runs in parallel with probeTLS to see which protocol is faster.
func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) {
	start := time.Now()

	timeout := p.timeout
	if timeout == 0 {
		timeout = dialTimeout
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	conn, err := quic.DialAddrEarly(ctx, addr, tlsConfig, p.getQUICConfig())
	if err != nil {
		ch <- fmt.Errorf("opening quic connection to %s: %w", p.addrRedacted, err)

		return
	}

	// The probe connection isn't needed anymore, and there's nothing useful
	// to do with an error from closing it.
	_ = conn.CloseWithError(QUICCodeNoError, "")

	ch <- nil

	p.logger.Debug("quic connection established", "elapsed", time.Since(start))
}
// probeTLS tries to establish a TLS connection through dialContext.  It sends
// the result to ch: nil on success, or a wrapped error otherwise.  It runs in
// parallel with probeQUIC to see which protocol is faster.
func (p *dnsOverHTTPS) probeTLS(dialContext bootstrap.DialHandler, tlsConfig *tls.Config, ch chan error) {
	start := time.Now()

	conn, err := tlsDial(dialContext, tlsConfig)
	if err != nil {
		ch <- fmt.Errorf("opening TLS connection: %w", err)

		return
	}

	// Only the fact that the connection could be established matters, so the
	// error from closing it is ignored.
	_ = conn.Close()

	ch <- nil

	p.logger.Debug("tls connection established", "elapsed", time.Since(start))
}
// supportsH3 reports whether HTTP/3 is among the ALPN protocols configured for
// this upstream.
func (p *dnsOverHTTPS) supportsH3() (ok bool) {
	h3 := string(HTTPVersion3)
	for i := range p.tlsConf.NextProtos {
		if p.tlsConf.NextProtos[i] == h3 {
			return true
		}
	}

	return false
}
// supportsHTTP reports whether HTTP/1.1 or HTTP/2 is among the ALPN protocols
// configured for this upstream.
func (p *dnsOverHTTPS) supportsHTTP() (ok bool) {
	for _, proto := range p.tlsConf.NextProtos {
		switch proto {
		case string(HTTPVersion11), string(HTTPVersion2):
			return true
		}
	}

	return false
}
// isHTTP3 reports whether client uses the HTTP/3 transport created by
// createTransportH3, that is, whether its transport is an *http3Transport.
func isHTTP3(client *http.Client) (ok bool) {
	switch client.Transport.(type) {
	case *http3Transport:
		return true
	default:
		return false
	}
}
07070100000097000081A4000000000000000000000001679A649F0000375E000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/upstream/doh_internal_test.gopackage upstream
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/stretchr/testify/require"
)
// TestUpstreamDoH checks three things: that the DoH upstream negotiates the
// expected HTTP version, that it answers queries, and that it resumes the TLS
// session after its HTTP client has been re-created.
func TestUpstreamDoH(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name             string
		expectedProtocol HTTPVersion
		httpVersions     []HTTPVersion
		delayHandshakeH3 time.Duration
		delayHandshakeH2 time.Duration
		http3Enabled     bool
	}{{
		name:             "http1.1_h2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion11, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "fallback_to_http2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "http3",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3},
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http3_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH2: time.Second,
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http2_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH3: time.Second,
		expectedProtocol: HTTPVersion2,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			srv := startDoHServer(t, testDoHServerOptions{
				http3Enabled:     tc.http3Enabled,
				delayHandshakeH2: tc.delayHandshakeH2,
				delayHandshakeH3: tc.delayHandshakeH3,
			})

			// Create a DNS-over-HTTPS upstream.
			address := fmt.Sprintf("https://%s/dns-query", srv.addr)

			// lastState is the state of the last verified connection.
			var lastState tls.ConnectionState
			opts := &Options{
				Logger:             slogutil.NewDiscardLogger(),
				InsecureSkipVerify: true,
				HTTPVersions:       tc.httpVersions,
				VerifyConnection: func(state tls.ConnectionState) (err error) {
					if state.NegotiatedProtocol != string(tc.expectedProtocol) {
						return fmt.Errorf(
							"expected %s, got %s",
							tc.expectedProtocol,
							state.NegotiatedProtocol,
						)
					}

					lastState = state

					return nil
				},
			}

			u, err := AddressToUpstream(address, opts)
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			// Test that it responds properly.
			for range 10 {
				checkUpstream(t, u, address)
			}

			doh := u.(*dnsOverHTTPS)

			// Trigger re-connection.  Close the current client first, so that
			// its connections don't outlive the subtest.  This matters
			// especially for the keep-alive QUIC connections of HTTP/3, since
			// u.Close only closes the client that's current at cleanup.  Hold
			// the lock, as the upstream itself does when replacing the client.
			func() {
				doh.clientMu.Lock()
				defer doh.clientMu.Unlock()

				closeErr := doh.closeClient(doh.client)
				require.NoError(t, closeErr)

				doh.client = nil
			}()

			// Force it to establish the connection again.
			checkUpstream(t, u, address)

			// Check that TLS session was resumed properly.
			require.True(t, lastState.DidResume)
		})
	}
}
// TestUpstreamDoH_raceReconnect is meant to be run with -race.
//
// Its server handler delays every tenth request beyond the upstream timeout.
// This makes the upstream re-create its HTTP client, which is important to
// test for race conditions.  Concurrency comes from checkRaceCondition, which
// presumably sends queries from several goroutines; verify against its
// definition.
func TestUpstreamDoH_raceReconnect(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name             string
		expectedProtocol HTTPVersion
		httpVersions     []HTTPVersion
		delayHandshakeH3 time.Duration
		delayHandshakeH2 time.Duration
		http3Enabled     bool
	}{{
		name:             "http1.1_h2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion11, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "fallback_to_http2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "http3",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3},
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http3_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH2: time.Second,
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http2_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH3: time.Second,
		expectedProtocol: HTTPVersion2,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			const timeout = 100 * time.Millisecond

			// reqNum counts the requests the handler has received so far.
			var reqNum atomic.Int32
			dohHandler := createDoHHandlerFunc()

			mux := http.NewServeMux()
			mux.HandleFunc("/dns-query", func(w http.ResponseWriter, r *http.Request) {
				// Answer every tenth request only after the upstream has
				// already timed out.
				if reqNum.Add(1)%10 == 0 {
					time.Sleep(2 * timeout)
				}

				dohHandler(w, r)
			})

			srv := startDoHServer(t, testDoHServerOptions{
				http3Enabled:     tc.http3Enabled,
				delayHandshakeH2: tc.delayHandshakeH2,
				delayHandshakeH3: tc.delayHandshakeH3,
				handler:          mux,
			})

			// The DNS-over-HTTPS upstream under the race test.
			u, err := AddressToUpstream(fmt.Sprintf("https://%s/dns-query", srv.addr), &Options{
				Logger:             slogutil.NewDiscardLogger(),
				InsecureSkipVerify: true,
				HTTPVersions:       tc.httpVersions,
				Timeout:            timeout,
			})
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkRaceCondition(u)
		})
	}
}
// TestUpstreamDoH_serverRestart checks that a DoH upstream keeps working when
// the server it talks to is stopped and started again on the same port.
//
// The subtests run sequentially and share addr, upsAddr, and u.  Each server
// is started with the subtest's t, so it is shut down when that subtest ends.
func TestUpstreamDoH_serverRestart(t *testing.T) {
	testCases := []struct {
		name         string
		httpVersions []HTTPVersion
	}{{
		name:         "http2",
		httpVersions: []HTTPVersion{HTTPVersion11, HTTPVersion2},
	}, {
		name:         "http3",
		httpVersions: []HTTPVersion{HTTPVersion3},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// These are set by first_try and used by the subtests after it.
			var addr netip.AddrPort
			var upsAddr string
			var u Upstream

			// first_try creates the upstream and checks it against a server
			// that only lives as long as this subtest.
			t.Run("first_try", func(t *testing.T) {
				srv := startDoHServer(t, testDoHServerOptions{
					http3Enabled: true,
				})

				addr = netip.MustParseAddrPort(srv.addr)
				upsAddr = (&url.URL{
					Scheme: "https",
					Host:   addr.String(),
					Path:   "dns-query",
				}).String()

				var err error
				u, err = AddressToUpstream(upsAddr, &Options{
					Logger:             slogutil.NewDiscardLogger(),
					InsecureSkipVerify: true,
					HTTPVersions:       tc.httpVersions,
					Timeout:            100 * time.Millisecond,
				})
				require.NoError(t, err)

				checkUpstream(t, u, upsAddr)
			})
			// Don't continue without a working upstream.
			require.False(t, t.Failed())

			testutil.CleanupAndRequireSuccess(t, u.Close)

			// second_try starts a new server on the same port and checks that
			// the same upstream works with it.
			t.Run("second_try", func(t *testing.T) {
				_ = startDoHServer(t, testDoHServerOptions{
					http3Enabled: true,
					port:         int(addr.Port()),
				})

				checkUpstream(t, u, upsAddr)
			})
			require.False(t, t.Failed())

			// retry first queries while no server is running, which must
			// fail.  It then starts a server again and checks that the
			// upstream recovers.
			t.Run("retry", func(t *testing.T) {
				_, err := u.Exchange(createTestMessage())
				require.Error(t, err)

				_ = startDoHServer(t, testDoHServerOptions{
					http3Enabled: true,
					port:         int(addr.Port()),
				})

				checkUpstream(t, u, upsAddr)
			})
		})
	}
}
// TestUpstreamDoH_0RTT checks that a DoH3 upstream uses 0-RTT when it
// reconnects to a server it has already talked to.
func TestUpstreamDoH_0RTT(t *testing.T) {
	t.Parallel()

	// The server that the upstream connects to twice.
	srv := startDoHServer(t, testDoHServerOptions{
		http3Enabled: true,
	})

	// The tracer records the QUIC connections made by the upstream.
	tracer := &quicTracer{}
	u, err := AddressToUpstream(fmt.Sprintf("h3://%s/dns-query", srv.addr), &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
		QUICTracer:         tracer.TracerForConnection,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	doh := u.(*dnsOverHTTPS)
	req := createTestMessage()

	// exchangeAndCheck sends req and checks the response.
	exchangeAndCheck := func() {
		resp, exchErr := doh.Exchange(req)
		require.NoError(t, exchErr)
		requireResponse(t, req, resp)
	}

	// The first exchange opens the initial connection to the DoH3 server.
	exchangeAndCheck()

	// Close the client together with its connection, so that the next
	// exchange has to reconnect.
	func() {
		doh.clientMu.Lock()
		defer doh.clientMu.Unlock()

		closeErr := doh.closeClient(doh.client)
		require.NoError(t, closeErr)

		doh.client = nil
	}()

	// The second exchange reconnects.
	exchangeAndCheck()

	// Exactly two connections must have been traced: the first one without
	// 0-RTT and the second one with it.
	conns := tracer.getConnectionsInfo()
	require.Len(t, conns, 2)
	require.False(t, conns[0].is0RTT())
	require.True(t, conns[1].is0RTT())
}
// testDoHServerOptions allows customizing testDoHServer behavior.  The zero
// value starts only the HTTP/1.1 and HTTP/2 server, with the default handler,
// on a random port.
type testDoHServerOptions struct {
	// handler is an HTTP handler that should be used by the server.  The
	// default one, created by createDoHHandler, is used on nil.
	handler http.Handler

	// delayHandshakeH2 is a delay that should be added to the handshake of the
	// HTTP/2 server.  Values that aren't positive mean no delay.
	delayHandshakeH2 time.Duration

	// delayHandshakeH3 is a delay that should be added to the handshake of the
	// HTTP/3 server.  Values that aren't positive mean no delay.
	delayHandshakeH3 time.Duration

	// port is the port that the server should listen to.  If it's 0, a random
	// port is used.  The HTTP/3 server, if enabled, uses the same port number
	// on UDP.
	port int

	// http3Enabled is a flag that indicates whether the server should start an
	// HTTP/3 server.
	http3Enabled bool
}
// testDoHServer is an instance of a test DNS-over-HTTPS server, as returned by
// startDoHServer.
type testDoHServer struct {
	// tlsConfig is the TLS configuration that is used for this server.
	tlsConfig *tls.Config

	// rootCAs is the pool with root certificates used by the test server.
	rootCAs *x509.CertPool

	// server is an HTTP/1.1 and HTTP/2 server.
	server *http.Server

	// serverH3 is an HTTP/3 server.  It is nil unless HTTP/3 is enabled.
	serverH3 *http3.Server

	// listenerH3 is the listener that's used to serve HTTP/3.  It is nil
	// unless HTTP/3 is enabled.
	listenerH3 *quic.EarlyListener

	// addr is the address that this server listens to.
	addr string
}
// Shutdown stops the DoH server.  It shuts down the HTTP/1.1 and HTTP/2
// server, then closes the HTTP/3 server and its listener if they were
// started.  Errors are ignored, since this is only used for test cleanup.
func (s *testDoHServer) Shutdown() {
	if srv := s.server; srv != nil {
		_ = srv.Shutdown(context.Background())
	}

	if s.serverH3 == nil {
		return
	}

	_ = s.serverH3.Close()
	_ = s.listenerH3.Close()
}
// startDoHServer starts a new DNS-over-HTTPS server with specified options.  It
// returns a started server instance with addr set.  Note that it adds its own
// shutdown to cleanup of t.
//
// The HTTP/1.1 and HTTP/2 server always runs over TCP.  When
// opts.http3Enabled is set, an HTTP/3 server is also started over UDP, on the
// same port number as the TCP listener.  Both use the TLS config returned by
// createServerTLSConfig for 127.0.0.1.
func startDoHServer(
	t *testing.T,
	opts testDoHServerOptions,
) (s *testDoHServer) {
	tlsConfig, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	handler := opts.handler
	if handler == nil {
		handler = createDoHHandler()
	}

	// Step one is to create a regular HTTP server, we'll always have it
	// running.
	server := &http.Server{
		Handler:     handler,
		ReadTimeout: time.Second,
	}

	// Listen TCP first.
	listenAddr := fmt.Sprintf("127.0.0.1:%d", opts.port)
	tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
	require.NoError(t, err)

	tcpListen, err := net.ListenTCP("tcp", tcpAddr)
	require.NoError(t, err)

	tlsConfigH2 := tlsConfig.Clone()
	tlsConfigH2.NextProtos = []string{string(HTTPVersion2), string(HTTPVersion11)}
	// GetConfigForClient is only a hook to delay the handshake.  Returning a
	// nil config keeps tlsConfigH2 in effect.
	tlsConfigH2.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
		if opts.delayHandshakeH2 > 0 {
			time.Sleep(opts.delayHandshakeH2)
		}

		return nil, nil
	}
	tlsListen := tls.NewListener(tcpListen, tlsConfigH2)

	// Run the H1/H2 server.
	go func() {
		// TODO(ameshkov): check the error here.
		_ = server.Serve(tlsListen)
	}()

	// Get the real address that the listener now listens to.
	tcpAddr = tcpListen.Addr().(*net.TCPAddr)

	var serverH3 *http3.Server
	var listenerH3 *quic.EarlyListener

	if opts.http3Enabled {
		tlsConfigH3 := tlsConfig.Clone()
		tlsConfigH3.NextProtos = []string{string(HTTPVersion3)}
		// A hook to delay the HTTP/3 handshake.  Returning a nil config keeps
		// tlsConfigH3 in effect.
		tlsConfigH3.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
			if opts.delayHandshakeH3 > 0 {
				time.Sleep(opts.delayHandshakeH3)
			}

			return nil, nil
		}

		serverH3 = &http3.Server{
			Handler: handler,
		}

		// Listen UDP for the H3 server.  Reuse the same port as was used for
		// the TCP listener.
		var udpAddr *net.UDPAddr
		udpAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port))
		require.NoError(t, err)

		var conn net.PacketConn
		conn, err = net.ListenUDP("udp", udpAddr)
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, conn.Close)

		// Returning false from VerifySourceAddress means that clients aren't
		// asked to validate their address with a QUIC Retry first.
		transport := &quic.Transport{
			Conn:                conn,
			VerifySourceAddress: func(net.Addr) bool { return false },
		}

		// QUIC configuration with the 0-RTT support enabled by default.
		listenerH3, err = transport.ListenEarly(tlsConfigH3, &quic.Config{
			Allow0RTT: true,
		})
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, transport.Close)

		// Run the H3 server.
		go func() {
			// TODO(ameshkov): check the error here.
			_ = serverH3.ServeListener(listenerH3)
		}()
	}

	s = &testDoHServer{
		tlsConfig:  tlsConfig,
		rootCAs:    rootCAs,
		server:     server,
		serverH3:   serverH3,
		listenerH3: listenerH3,
		// Save the address that the server listens to.
		addr: tcpAddr.String(),
	}

	// t.Cleanup runs functions in reverse order of registration.  Assuming
	// testutil.CleanupAndRequireSuccess registers through t.Cleanup as well,
	// the servers are shut down before the QUIC transport and the UDP socket
	// are closed.
	t.Cleanup(s.Shutdown)

	return s
}
// createDoHHandlerFunc creates a simple http.HandlerFunc for tests.  It decodes
// the DNS message from the base64url-encoded "dns" query parameter and writes
// the response built by respondToTestMessage.  Decoding, unpacking, and
// packing failures are reported to the client with status 500.
func createDoHHandlerFunc() (f http.HandlerFunc) {
	return func(w http.ResponseWriter, r *http.Request) {
		// replyError reports an internal error to the client.
		replyError := func(err error) {
			http.Error(
				w,
				fmt.Sprintf("internal error: %s", err),
				http.StatusInternalServerError,
			)
		}

		buf, err := base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
		if err != nil {
			replyError(err)

			return
		}

		m := &dns.Msg{}
		err = m.Unpack(buf)
		if err != nil {
			replyError(err)

			return
		}

		buf, err = respondToTestMessage(m).Pack()
		if err != nil {
			replyError(err)

			return
		}

		w.Header().Set("Content-Type", "application/dns-message")

		// Don't panic on a write error.  The client may legitimately be gone
		// by now, e.g. when a test delays the response beyond the client's
		// timeout.  A panic wouldn't fail the test anyway: net/http recovers
		// handler panics and only logs them.
		_, _ = w.Write(buf)
	}
}
// createDoHHandler returns an http.Handler that serves the handler from
// createDoHHandlerFunc at the "/dns-query" path.
func createDoHHandler() (h http.Handler) {
	mux := http.NewServeMux()
	mux.Handle("/dns-query", createDoHHandlerFunc())

	return mux
}
07070100000098000081A4000000000000000000000001679A649F00003DB0000000000000000000000000000000000000002000000000dnsproxy-0.75.0/upstream/doq.gopackage upstream
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/url"
"os"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
)
const (
	// QUICCodeNoError is used when the connection or stream needs to be closed,
	// but there is no error to signal.  It corresponds to DOQ_NO_ERROR in
	// RFC 9250.
	QUICCodeNoError = quic.ApplicationErrorCode(0)

	// QUICCodeInternalError signals that the DoQ implementation encountered
	// an internal error and is incapable of pursuing the transaction or the
	// connection.  It corresponds to DOQ_INTERNAL_ERROR in RFC 9250.
	QUICCodeInternalError = quic.ApplicationErrorCode(1)

	// QUICKeepAlivePeriod is the value that we pass to *quic.Config.  It
	// controls the period with which keep-alive frames are sent on the
	// connection.  We set it to 20s, as it would be in quic-go@v0.27.1 with
	// the KeepAlive field set to true.  This value is specified in
	// https://pkg.go.dev/github.com/quic-go/quic-go/internal/protocol#MaxKeepAliveInterval.
	//
	// TODO(ameshkov): Consider making it configurable.
	QUICKeepAlivePeriod = time.Second * 20

	// NextProtoDQ is the ALPN token for DoQ.  During the connection
	// establishment, DNS/QUIC support is indicated by selecting the ALPN token
	// "doq" in the crypto handshake.
	//
	// See https://datatracker.ietf.org/doc/rfc9250.
	NextProtoDQ = "doq"
)
// compatProtoDQ is the list of ALPN tokens offered by DoQ connections (see
// NextProtos in newDoQ).  NextProtoDQ, the token defined by RFC 9250, comes
// first.  The tokens of earlier drafts follow, presumably for compatibility
// with servers that only implement a draft.
var compatProtoDQ = []string{NextProtoDQ, "doq-i00", "dq", "doq-i02"}
// dnsOverQUIC implements the [Upstream] interface for the DNS-over-QUIC
// protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html).
type dnsOverQUIC struct {
	// getDialer either returns an initialized dial handler or creates a new
	// one.
	getDialer DialerInitializer

	// addr is the DNS-over-QUIC server URL.
	addr *url.URL

	// tlsConf is the configuration of TLS.
	tlsConf *tls.Config

	// quicConfig is the QUIC configuration that is used for establishing
	// connections to the upstream.  This configuration includes the
	// TokenStore, which needs to be kept for the lifetime of dnsOverQUIC since
	// the connection can be re-created.
	quicConfig *quic.Config

	// conn is the current active QUIC connection.  It can be closed and
	// re-opened when needed.
	conn quic.Connection

	// bytesPool is a *sync.Pool we use to store byte buffers in.  These byte
	// buffers are used to read responses from the upstream.
	bytesPool *sync.Pool

	// quicConfigMu protects quicConfig.
	quicConfigMu *sync.Mutex

	// connMu protects conn.
	connMu *sync.Mutex

	// bytesPoolMu protects bytesPool.
	bytesPoolMu *sync.Mutex

	// logger is used for exchange logging.  It is never nil.
	logger *slog.Logger

	// timeout is the timeout for the upstream connection.  exchangeQUIC only
	// applies it when it is positive.
	timeout time.Duration
}
// newDoQ returns the DNS-over-QUIC Upstream for addr.  addr is first passed to
// addPort with defaultPortDoQ, presumably to fill in the default DoQ port when
// none is set.
//
// A finalizer closes the upstream if it becomes unreachable without being
// closed explicitly.  Close removes that finalizer.
func newDoQ(addr *url.URL, opts *Options) (u Upstream, err error) {
	addPort(addr, defaultPortDoQ)

	getDialer := newDialerInitializer(addr, opts)

	quicConf := &quic.Config{
		KeepAlivePeriod: QUICKeepAlivePeriod,
		TokenStore:      newQUICTokenStore(),
		Tracer:          opts.QUICTracer,
	}

	tlsConf := &tls.Config{
		ServerName:   addr.Hostname(),
		RootCAs:      opts.RootCAs,
		CipherSuites: opts.CipherSuites,
		// Use the default capacity for the LRU cache.  It may be useful to
		// store several caches since the user may be routed to different
		// servers in case there's load balancing on the server-side.
		ClientSessionCache: tls.NewLRUClientSessionCache(0),
		MinVersion:         tls.VersionTLS12,
		// #nosec G402 -- TLS certificate verification could be disabled by
		// configuration.
		InsecureSkipVerify:    opts.InsecureSkipVerify,
		VerifyPeerCertificate: opts.VerifyServerCertificate,
		VerifyConnection:      opts.VerifyConnection,
		NextProtos:            compatProtoDQ,
	}

	doq := &dnsOverQUIC{
		getDialer:    getDialer,
		addr:         addr,
		quicConfig:   quicConf,
		tlsConf:      tlsConf,
		quicConfigMu: &sync.Mutex{},
		connMu:       &sync.Mutex{},
		bytesPoolMu:  &sync.Mutex{},
		logger:       opts.Logger,
		timeout:      opts.Timeout,
	}

	runtime.SetFinalizer(doq, (*dnsOverQUIC).Close)

	return doq, nil
}
// type check
var _ Upstream = (*dnsOverQUIC)(nil)

// Address implements the [Upstream] interface for *dnsOverQUIC.  It returns
// the string form of the DNS-over-QUIC server URL.
func (p *dnsOverQUIC) Address() string {
	return p.addr.String()
}
// Exchange implements the [Upstream] interface for *dnsOverQUIC.
//
// The query is sent with a zero message ID, and the original ID is restored
// on both req and the response afterwards.  If sending over a cached
// connection fails, that connection is closed and the query is retried once
// over a connection obtained again from getConnection.  If the final attempt
// fails, the connection used is closed with an error.
func (p *dnsOverQUIC) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	// When sending queries over a QUIC connection, the DNS Message ID MUST be
	// set to 0.  The stream mapping for DoQ allows for unambiguous correlation
	// of queries and responses, so the Message ID field is not required.
	//
	// See https://www.rfc-editor.org/rfc/rfc9250#section-4.2.1.
	origID := req.Id
	req.Id = 0
	defer func() {
		// Put the original ID back to not break compatibility with proxies.
		req.Id = origID
		if resp != nil {
			resp.Id = origID
		}
	}()

	conn, cached, err := p.getConnection()
	if err != nil {
		return nil, fmt.Errorf("getting conn: %w", err)
	}

	resp, err = p.exchangeQUIC(req, conn)
	if err != nil && cached {
		// A cached connection may have been closed by the server or broken by
		// the way UDP NAT works, so re-create it and try again.
		p.logger.Debug("recreating the quic connection and retrying", slogutil.KeyError, err)

		// Close the failed connection, so that it is cleaned up.
		p.closeConnWithError(conn, err)

		conn, _, err = p.getConnection()
		if err != nil {
			return nil, fmt.Errorf("getting new conn: %w", err)
		}

		resp, err = p.exchangeQUIC(req, conn)
	}

	if err != nil {
		// The exchange failed, so make sure the connection is closed and an
		// internal error is signaled.
		p.closeConnWithError(conn, err)
	}

	return resp, err
}
// Close implements the [Upstream] interface for *dnsOverQUIC.  It removes the
// finalizer, closes the cached connection, if any, with QUICCodeNoError, and
// forgets it.  Calling Close again is a no-op that returns nil.
func (p *dnsOverQUIC) Close() (err error) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	runtime.SetFinalizer(p, nil)

	if p.conn == nil {
		return nil
	}

	err = p.conn.CloseWithError(QUICCodeNoError, "")

	// Drop the reference so that the closed connection is neither closed
	// twice nor handed out again by getConnection.
	p.conn = nil

	return err
}
// exchangeQUIC sends req over a new stream of conn and returns the server's
// response.  The query is written with its 2-octet length prefix, after which
// the write direction of the stream is closed.  When p.timeout is positive it
// is applied as the stream deadline.
func (p *dnsOverQUIC) exchangeQUIC(req *dns.Msg, conn quic.Connection) (resp *dns.Msg, err error) {
	upsAddr := p.Address()
	logBegin(p.logger, upsAddr, networkUDP, req)
	defer func() { logFinish(p.logger, upsAddr, networkUDP, err) }()

	packed, err := req.Pack()
	if err != nil {
		return nil, fmt.Errorf("failed to pack DNS message for DoQ: %w", err)
	}

	stream, err := p.openStream(conn)
	if err != nil {
		return nil, fmt.Errorf("opening stream: %w", err)
	}

	if p.timeout > 0 {
		if err = stream.SetDeadline(time.Now().Add(p.timeout)); err != nil {
			return nil, fmt.Errorf("setting deadline: %w", err)
		}
	}

	if _, err = stream.Write(proxyutil.AddPrefix(packed)); err != nil {
		return nil, fmt.Errorf("failed to write to a QUIC stream: %w", err)
	}

	// The client MUST send the DNS query over the selected stream, and MUST
	// indicate through the STREAM FIN mechanism that no further data will be
	// sent on that stream.  stream.Close only closes the write direction, so
	// the response can still be read afterwards.
	if closeErr := stream.Close(); closeErr != nil {
		p.logger.Debug("closing quic stream", slogutil.KeyError, closeErr)
	}

	return p.readMsg(stream)
}
// getBytesPool returns (creates if needed) a pool we store byte buffers in.
// Each buffer is large enough for the 2-octet length prefix followed by a DNS
// message of the maximum size, which is how DoQ messages are framed on the
// wire, see RFC 9250, Section 4.2.
func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
	p.bytesPoolMu.Lock()
	defer p.bytesPoolMu.Unlock()

	if p.bytesPool == nil {
		p.bytesPool = &sync.Pool{
			New: func() (v any) {
				// Account for the length prefix, otherwise a maximum-size
				// message wouldn't fit.
				b := make([]byte, 2+dns.MaxMsgSize)

				return &b
			},
		}
	}

	return p.bytesPool
}
// getConnection returns the cached QUIC connection or, if there is none,
// opens a new one and caches it.  cached is true only when the returned
// connection was taken from the cache.
func (p *dnsOverQUIC) getConnection() (conn quic.Connection, cached bool, err error) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	if p.conn != nil {
		return p.conn, true, nil
	}

	if conn, err = p.openConnection(); err != nil {
		return nil, false, err
	}

	p.conn = conn

	return conn, false, nil
}
// getQUICConfig returns the current QUIC config under quicConfigMu.  The
// returned pointer is shared, so callers must not modify the config it points
// to.
func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) {
	p.quicConfigMu.Lock()
	defer p.quicConfigMu.Unlock()

	c = p.quicConfig

	return c
}
// resetQUICConfig replaces the QUIC config with a copy that uses a fresh token
// store, since the old tokens may be the reason the connection failed.  The
// previous config value itself is left untouched.
func (p *dnsOverQUIC) resetQUICConfig() {
	p.quicConfigMu.Lock()
	defer p.quicConfigMu.Unlock()

	conf := p.quicConfig.Clone()
	conf.TokenStore = newQUICTokenStore()
	p.quicConfig = conf
}
// openStream opens a new QUIC stream on conn, waiting no longer than the
// upstream's timeout when one is set.
func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) {
	ctx, cancel := p.withDeadline(context.Background())
	defer cancel()

	s, err := conn.OpenStreamSync(ctx)
	if err == nil {
		return s, nil
	}

	return nil, fmt.Errorf("failed to open a QUIC stream: %w", err)
}
// openConnection dials a new QUIC connection to the upstream.
//
// A UDP "connection" is first obtained from the dial handler.  It sends
// nothing and is closed right away; it only tells which of the bootstrapped
// addresses is reachable, e.g. when there are both IPv4 and IPv6 ones.  The
// QUIC connection is then dialed to that remote address.
func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
	dialContext, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("bootstrapping %s: %w", p.addr, err)
	}

	// The address argument is left empty: the dial handler uses the
	// bootstrapped address instead of what's passed to it.
	rawConn, err := dialContext(context.Background(), "udp", "")
	if err != nil {
		return nil, fmt.Errorf("dialing raw connection to %s: %w", p.addr, err)
	}

	// Only the remote address of the raw connection is used.
	if closeErr := rawConn.Close(); closeErr != nil {
		p.logger.Debug("closing raw connection", "addr", p.addr, slogutil.KeyError, closeErr)
	}

	udpConn, ok := rawConn.(*net.UDPConn)
	if !ok {
		return nil, fmt.Errorf("unexpected type %T of connection; should be %T", rawConn, udpConn)
	}

	remoteAddr := udpConn.RemoteAddr().String()

	ctx, cancel := p.withDeadline(context.Background())
	defer cancel()

	tlsConf := p.tlsConf.Clone()
	quicConf := p.getQUICConfig()

	conn, err = quic.DialAddrEarly(ctx, remoteAddr, tlsConf, quicConf)
	if err != nil {
		return nil, fmt.Errorf("dialing quic connection to %s: %w", p.addr, err)
	}

	return conn, nil
}
// closeConnWithError closes conn so that new queries are processed over
// another connection.  conn is closed with QUICCodeInternalError when err is
// not nil and with QUICCodeNoError otherwise.  If err is
// quic.Err0RTTRejected, the QUIC config is reset to get rid of the stale
// tokens.  If conn is the cached connection, the cache is cleared.
func (p *dnsOverQUIC) closeConnWithError(conn quic.Connection, err error) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	if errors.Is(err, quic.Err0RTTRejected) {
		// Reset the TokenStore only if 0-RTT was rejected.
		p.resetQUICConfig()
	}

	if p.conn == conn {
		p.conn = nil
	}

	code := QUICCodeNoError
	if err != nil {
		code = QUICCodeInternalError
	}

	if closeErr := conn.CloseWithError(code, ""); closeErr != nil {
		p.logger.Error("failed to close the conn", slogutil.KeyError, closeErr)
	}
}
// readMsg reads the incoming DNS message from the QUIC stream.
//
// All DNS messages (queries and responses) sent over DoQ connections MUST be
// encoded as a 2-octet length field followed by the message content as
// specified in [RFC1035], see RFC 9250, Section 4.2.  A single Read may
// return only a part of the data, so readMsg keeps reading until the whole
// announced message is buffered.  Only the announced bytes are unpacked,
// because the buffer comes from a pool and may still hold the tail of a
// previous, longer response.
//
// IMPORTANT: Only the first message on the stream is read; the read side of
// the stream is cancelled right after it.
func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *dns.Msg, err error) {
	pool := p.getBytesPool()
	bufPtr := pool.Get().(*[]byte)
	defer pool.Put(bufPtr)

	respBuf := *bufPtr

	// n is the number of bytes read so far.  msgLen is the message length
	// announced by the prefix, or -1 until the whole prefix has been read.
	n, msgLen := 0, -1
	complete := func() (ok bool) { return msgLen >= 0 && n >= 2+msgLen }
	for !complete() {
		if n == len(respBuf) {
			return nil, fmt.Errorf(
				"reading response from %s: response exceeds %d bytes",
				p.addr,
				len(respBuf),
			)
		}

		read, readErr := stream.Read(respBuf[n:])
		n += read
		if msgLen < 0 && n >= 2 {
			msgLen = int(respBuf[0])<<8 | int(respBuf[1])
		}

		// Read may return the last chunk of data together with an error,
		// e.g. io.EOF, so only fail if the message is still incomplete.
		if readErr != nil && !complete() {
			return nil, fmt.Errorf("reading response from %s: %w", p.addr, readErr)
		}
	}

	stream.CancelRead(0)

	m = new(dns.Msg)
	err = m.Unpack(respBuf[2 : 2+msgLen])
	if err != nil {
		return nil, fmt.Errorf("unpacking response from %s: %w", p.addr, err)
	}

	return m, nil
}
// newQUICTokenStore creates the quic.TokenStore that is needed to benefit
// from 0-RTT.
//
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8.1 for more on
// address validation.
func newQUICTokenStore() (s quic.TokenStore) {
	// A single origin with up to ten tokens is assumed to be more than enough
	// for the way the store is used: one connection per upstream.
	const (
		maxOrigins      = 1
		tokensPerOrigin = 10
	)

	return quic.NewLRUTokenStore(maxOrigins, tokensPerOrigin)
}
// isQUICRetryError reports whether err may signal that the QUIC connection
// should be re-created.  This requirement is caused by quic-go issues, see the
// comments on the individual cases.
//
// TODO(ameshkov): re-test when updating quic-go.
func isQUICRetryError(err error) (ok bool) {
	var (
		appErr       *quic.ApplicationError
		idleErr      *quic.IdleTimeoutError
		resetErr     *quic.StatelessResetError
		transportErr *quic.TransportError
	)

	switch {
	case errors.As(err, &appErr) &&
		(appErr.ErrorCode == 0 ||
			appErr.ErrorCode == quic.ApplicationErrorCode(http3.ErrCodeNoError)):
		// Error code 0 is often returned when the server has been restarted,
		// and we try to use the same connection on the client-side.
		// http3.ErrCodeNoError may be used by an HTTP/3 server when closing
		// an idle connection.  These connections are not immediately closed
		// by the HTTP client so this case should be handled.
		return true
	case errors.As(err, &idleErr):
		// The connection was closed due to being idle, so it should be
		// forcibly re-created.  To reproduce, stop the server, wait for 30
		// seconds, and send another request via the same upstream.
		return true
	case errors.As(err, &resetErr):
		// A stateless reset is sent when a server receives a QUIC packet
		// that it doesn't know how to decrypt, e.g. when it was recently
		// rebooted.  Reconnecting and trying again is the right thing here.
		return true
	case errors.As(err, &transportErr) && transportErr.ErrorCode == quic.NoError:
		// A transport error with the NO_ERROR code could be sent by a server
		// that considers it's time to close the connection.  For example,
		// Google DNS eventually closes an active connection with NO_ERROR and
		// the "Connection max age expired" message:
		// https://github.com/AdguardTeam/dnsproxy/issues/283
		return true
	case errors.Is(err, quic.Err0RTTRejected):
		// Happens when a 0-RTT connection is attempted with a token the server
		// no longer knows, e.g. after a restart cleared its tokens cache.  The
		// error repeats until the client's tokens cache is purged.
		return true
	case errors.Is(err, os.ErrDeadlineExceeded):
		// A timeout that could happen when the server has been restarted.
		return true
	default:
		return false
	}
}
// withDeadline returns a context derived from parent that expires after the
// upstream's timeout, along with its cancel function.  When the timeout is not
// positive, parent itself is returned together with a no-op cancel function.
func (p *dnsOverQUIC) withDeadline(
	parent context.Context,
) (ctx context.Context, cancel context.CancelFunc) {
	if p.timeout <= 0 {
		return parent, func() {}
	}

	return context.WithDeadline(parent, time.Now().Add(p.timeout))
}
07070100000099000081A4000000000000000000000001679A649F00002EC1000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/upstream/doq_internal_test.gopackage upstream
import (
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"net/url"
"sync"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestUpstreamDoQ checks that the DoQ upstream answers queries over a single
// reused connection, resumes the TLS session after that connection is closed,
// and can be used concurrently.
func TestUpstreamDoQ(t *testing.T) {
	tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	srv := startDoQServer(t, tlsConf, 0)
	address := fmt.Sprintf("quic://%s", srv.addr)

	var lastState tls.ConnectionState
	opts := &Options{
		Logger:  slogutil.NewDiscardLogger(),
		RootCAs: rootCAs,
		VerifyConnection: func(state tls.ConnectionState) (err error) {
			lastState = state

			return nil
		},
	}

	u, err := AddressToUpstream(address, opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	uq := testutil.RequireTypeAssert[*dnsOverQUIC](t, u)

	// Every query must go through the connection opened for the first one.
	var firstConn quic.Connection
	for range 10 {
		checkUpstream(t, u, address)

		if firstConn == nil {
			firstConn = uq.conn

			continue
		}

		require.Equal(t, firstConn, uq.conn)
	}

	// Break the connection and make sure a new one resumes the session.
	_ = firstConn.CloseWithError(quic.ApplicationErrorCode(0), "")
	checkUpstream(t, u, address)
	require.True(t, lastState.DidResume)

	// Use a fresh upstream to check initialization under concurrent use.
	raceUps, err := AddressToUpstream(address, opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, raceUps.Close)

	checkRaceCondition(raceUps)
}
// TestUpstream_Exchange_quicServerCloseConn checks that the DoQ upstream
// recovers when the server closes the connection while several queries are
// sent through the upstream in parallel.
func TestUpstream_Exchange_quicServerCloseConn(t *testing.T) {
	// Use the same tlsConf for all servers to preserve the data necessary for
	// 0-RTT connections.
	tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")

	// Run the first server instance.
	srv := startDoQServer(t, tlsConf, 0)

	// Create a DNS-over-QUIC upstream.
	address := fmt.Sprintf("quic://%s", srv.addr)
	u, err := AddressToUpstream(address, &Options{
		Logger:  slogutil.NewDiscardLogger(),
		RootCAs: rootCAs,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Test that the upstream works properly.
	checkUpstream(t, u, address)

	// Close all active connections.
	err = srv.closeConns()
	require.NoError(t, err)

	// Now run several queries in parallel to check that the error from the
	// following issue is not happening:
	// https://github.com/AdguardTeam/dnsproxy/issues/389.
	//
	// Run 10 queries in parallel as the initial testing showed that this is
	// enough to trigger the race issue.
	const parallelQueries = 10

	wg := &sync.WaitGroup{}
	wg.Add(parallelQueries)

	// Use the same constant for the loop as for wg.Add, so that the two can't
	// get out of sync.
	for range parallelQueries {
		go func() {
			defer wg.Done()

			pt := testutil.PanicT{}
			_, errExch := u.Exchange(createTestMessage())
			assert.NoError(pt, errExch)
		}()
	}

	wg.Wait()
}
// TestUpstreamDoQ_serverRestart checks that a DoQ upstream keeps working after
// the server it talks to is restarted on the same address.  The subtests share
// u, addr, and upsStr, so they must run in order.
func TestUpstreamDoQ_serverRestart(t *testing.T) {
	t.Parallel()

	// Use the same tlsConf for all servers to preserve the data necessary for
	// 0-RTT connections.
	tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")

	var addr netip.AddrPort
	var upsStr string
	var u Upstream

	t.Run("first_try", func(t *testing.T) {
		// This server is shut down by the subtest's cleanups, once the
		// subtest is finished.
		srv := startDoQServer(t, tlsConf, 0)
		addr = netip.MustParseAddrPort(srv.addr)
		upsStr = (&url.URL{
			Scheme: "quic",
			Host:   addr.String(),
		}).String()

		var err error
		u, err = AddressToUpstream(
			upsStr,
			&Options{
				Logger:  slogutil.NewDiscardLogger(),
				RootCAs: rootCAs,
				// The short timeout bounds how long the failing exchange in
				// the "retry" subtest can take.
				Timeout: 100 * time.Millisecond,
			},
		)
		require.NoError(t, err)

		checkUpstream(t, u, upsStr)
	})
	// The following subtests use u, so stop here if it wasn't created.
	require.False(t, t.Failed())
	testutil.CleanupAndRequireSuccess(t, u.Close)

	t.Run("second_try", func(t *testing.T) {
		// Start a new server on the same port, the upstream still holds the
		// connection to the previous one.
		_ = startDoQServer(t, tlsConf, int(addr.Port()))

		checkUpstream(t, u, upsStr)
	})
	require.False(t, t.Failed())

	t.Run("retry", func(t *testing.T) {
		// The server started in "second_try" has been shut down by its
		// subtest's cleanups, so this exchange must fail.
		_, err := u.Exchange(createTestMessage())
		require.Error(t, err)

		_ = startDoQServer(t, tlsConf, int(addr.Port()))

		checkUpstream(t, u, upsStr)
	})
}
// TestUpstreamDoQ_0RTT checks that the first connection of a DoQ upstream is a
// regular one and that a reconnection to the same server uses 0-RTT.
func TestUpstreamDoQ_0RTT(t *testing.T) {
	tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	srv := startDoQServer(t, tlsConf, 0)

	tracer := &quicTracer{}
	address := fmt.Sprintf("quic://%s", srv.addr)
	u, err := AddressToUpstream(address, &Options{
		Logger:     slogutil.NewDiscardLogger(),
		QUICTracer: tracer.TracerForConnection,
		RootCAs:    rootCAs,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	uq := u.(*dnsOverQUIC)
	req := createTestMessage()

	exchange := func() {
		resp, exchErr := uq.Exchange(req)
		require.NoError(t, exchErr)
		requireResponse(t, req, resp)
	}

	// The first exchange makes the upstream connect to the server.
	exchange()

	// Close and forget the active connection so that the next exchange has to
	// reconnect.
	func() {
		uq.connMu.Lock()
		defer uq.connMu.Unlock()

		closeErr := uq.conn.CloseWithError(QUICCodeNoError, "")
		require.NoError(t, closeErr)

		uq.conn = nil
	}()

	// The second exchange opens the second connection.
	exchange()

	conns := tracer.getConnectionsInfo()
	require.Len(t, conns, 2)

	// Only the second connection is expected to carry 0-RTT packets.
	require.False(t, conns[0].is0RTT())
	require.True(t, conns[1].is0RTT())
}
// testDoQServer is an instance of a test DNS-over-QUIC server.
type testDoQServer struct {
	// listener is the QUIC connections listener.
	listener *quic.EarlyListener

	// logger is used for serving errors logging.
	logger *slog.Logger

	// conns is the set of connections that are currently active.
	conns map[quic.EarlyConnection]struct{}

	// connsMu protects conns.
	connsMu *sync.Mutex

	// addr is the address that this server listens to.
	addr string
}
// Shutdown stops the test server.  It closes all active connections first and
// the listener second, and returns the errors of both steps joined.
func (s *testDoQServer) Shutdown() (err error) {
	return errors.Join(s.closeConns(), s.listener.Close())
}
// Serve accepts incoming DoQ connections and handles each of them in a
// separate goroutine.  Every Accept call is bounded by a 5-second timeout;
// that timeout expiring isn't an error and serving simply goes on.  Serve
// returns when accepting fails for any other reason, e.g. after Shutdown has
// closed the listener.
func (s *testDoQServer) Serve() {
	for {
		var conn quic.EarlyConnection
		var err error
		func() {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()

			conn, err = s.listener.Accept(ctx)
		}()

		switch {
		case err == nil:
			go s.handleQUICConnection(conn)
		case errors.Is(err, context.DeadlineExceeded):
			// No client has connected in time.  Previously this stopped the
			// server for good, which broke any test that was idle for longer
			// than the timeout.  Just try again.
			continue
		case errors.Is(err, quic.ErrServerClosed):
			s.logger.Debug("accept failed", slogutil.KeyError, err)

			return
		default:
			s.logger.Error("accept failed", slogutil.KeyError, err)

			return
		}
	}
}
// handleQUICConnection registers conn as active and serves each stream it
// accepts in a separate goroutine, until accepting a stream fails.  Then conn
// is closed and unregistered.  A stream that can't be handled causes the whole
// connection to be closed.
func (s *testDoQServer) handleQUICConnection(conn quic.EarlyConnection) {
	s.addConn(conn)
	defer s.closeConn(conn)

	ctx := context.Background()
	for {
		stream, err := conn.AcceptStream(ctx)
		if err != nil {
			return
		}

		go func(ctx context.Context, stream quic.Stream) {
			if hErr := s.handleQUICStream(ctx, stream); hErr != nil {
				s.logger.Error("handling", "raddr", conn.RemoteAddr(), slogutil.KeyError, hErr)
				_ = conn.CloseWithError(QUICCodeNoError, "")
			}
		}(ctx, stream)
	}
}
// handleQUICStream handles new QUIC streams, reads DNS messages and responds to
// them.  Messages in both directions are prefixed with their length encoded
// as a 2-octet big-endian integer, see RFC 9250, Section 4.2.  The stream is
// closed when the function returns.
func (s *testDoQServer) handleQUICStream(ctx context.Context, stream quic.Stream) (err error) {
	defer slogutil.CloseAndLog(ctx, s.logger, stream, slog.LevelDebug)

	// A single Read may return only a part of the query, so read the length
	// prefix first and then exactly as many bytes as it announces.  This also
	// avoids computing the message end as a uint16, which could overflow.
	var prefix [2]byte
	_, err = io.ReadFull(stream, prefix[:])
	if err != nil {
		return fmt.Errorf("reading length prefix: %w", err)
	}

	buf := make([]byte, binary.BigEndian.Uint16(prefix[:]))
	_, err = io.ReadFull(stream, buf)
	if err != nil {
		return fmt.Errorf("reading message: %w", err)
	}

	stream.CancelRead(0)

	req := &dns.Msg{}
	err = req.Unpack(buf)
	if err != nil {
		return err
	}

	resp := respondToTestMessage(req)
	buf, err = resp.Pack()
	if err != nil {
		return err
	}

	_, err = stream.Write(proxyutil.AddPrefix(buf))

	return err
}
// addConn registers conn in the set of active connections.
func (s *testDoQServer) addConn(conn quic.EarlyConnection) {
	mu := s.connsMu
	mu.Lock()
	defer mu.Unlock()

	s.conns[conn] = struct{}{}
}
// closeConn removes conn from the set of active connections and closes it with
// QUICCodeNoError.  A failure to close is only logged.
func (s *testDoQServer) closeConn(conn quic.EarlyConnection) {
	s.connsMu.Lock()
	defer s.connsMu.Unlock()

	delete(s.conns, conn)

	if err := conn.CloseWithError(QUICCodeNoError, ""); err != nil {
		s.logger.Debug("failed to close conn", slogutil.KeyError, err)
	}
}
// closeConns closes every active connection with QUICCodeNoError and empties
// the set.  The errors of the individual closes are joined.
func (s *testDoQServer) closeConns() (err error) {
	s.connsMu.Lock()
	defer s.connsMu.Unlock()

	var errs []error
	for conn := range s.conns {
		if closeErr := conn.CloseWithError(QUICCodeNoError, ""); closeErr != nil {
			errs = append(errs, closeErr)
		}
	}

	clear(s.conns)

	return errors.Join(errs...)
}
// startDoQServer starts a test DoQ server on 127.0.0.1:port, where port 0
// picks a free port.  It sets tlsConf.NextProtos to the DoQ protocol and
// registers the socket, the transport, and the server's own shutdown in the
// cleanups of t.
func startDoQServer(t *testing.T, tlsConf *tls.Config, port int) (s *testDoQServer) {
	tlsConf.NextProtos = []string{NextProtoDQ}

	udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", port))
	require.NoError(t, err)

	udpConn, err := net.ListenUDP("udp", udpAddr)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, udpConn.Close)

	tr := &quic.Transport{
		Conn: udpConn,
		// Necessary for 0-RTT.
		VerifySourceAddress: func(_ net.Addr) (ok bool) { return true },
	}

	ln, err := tr.ListenEarly(tlsConf, &quic.Config{Allow0RTT: true})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, tr.Close)

	s = &testDoQServer{
		addr:     ln.Addr().String(),
		listener: ln,
		// TODO(d.kolyshev): Add a concurrent safe [slog.Handler] wrapper for
		// [testing.TB] log function.
		logger:  slogutil.NewDiscardLogger(),
		conns:   map[quic.EarlyConnection]struct{}{},
		connsMu: &sync.Mutex{},
	}

	go s.Serve()
	testutil.CleanupAndRequireSuccess(t, s.Shutdown)

	return s
}
// quicTracer creates a quicConnTracer for every QUIC connection it is asked to
// trace and keeps them all.  Its TracerForConnection method is what gets
// passed as Options.QUICTracer.
type quicTracer struct {
	// tracers are the per-connection tracers in the order they were created.
	tracers []*quicConnTracer

	// mu protects fields of *quicTracer and also protects fields of every
	// nested *quicConnTracer.
	mu sync.Mutex
}
// TracerForConnection creates and remembers a tracer for the connection with
// the original destination connection ID odcid.  It returns a
// logging.ConnectionTracer that only records sent long-header packets.
func (q *quicTracer) TracerForConnection(
	_ context.Context,
	_ logging.Perspective,
	odcid logging.ConnectionID,
) (connTracer *logging.ConnectionTracer) {
	ct := &quicConnTracer{id: odcid, parent: q}

	q.mu.Lock()
	defer q.mu.Unlock()

	q.tracers = append(q.tracers, ct)

	return &logging.ConnectionTracer{
		SentLongHeaderPacket: ct.SentLongHeaderPacket,
	}
}
// connInfo contains information about the packets traced for a single
// connection.
type connInfo struct {
	// packets are the headers of the long-header packets sent on the
	// connection, in the order they were recorded.
	packets []logging.Header

	// id is the original destination connection ID of the connection.
	id logging.ConnectionID
}
// is0RTT reports whether any of the connection's recorded packets is a 0-RTT
// packet.
func (c *connInfo) is0RTT() (ok bool) {
	for i := range c.packets {
		// Pass a copy, so that the recorded header can't be changed.
		hdr := c.packets[i]
		if logging.PacketTypeFromHeader(&hdr) == logging.PacketType0RTT {
			return true
		}
	}

	return false
}
// getConnectionsInfo returns the traced connections' information, in the
// order the connections were created.  The packet lists are copied under the
// lock, so the result doesn't share memory with the mutex-guarded tracers,
// which may keep recording packets after the call.
func (q *quicTracer) getConnectionsInfo() (conns []connInfo) {
	q.mu.Lock()
	defer q.mu.Unlock()

	for _, tracer := range q.tracers {
		conns = append(conns, connInfo{
			id:      tracer.id,
			packets: append([]logging.Header(nil), tracer.packets...),
		})
	}

	return conns
}
// quicConnTracer records the packets of a single QUIC connection for its
// parent quicTracer.
type quicConnTracer struct {
	// parent is the tracer that created this one.  Its mu protects packets.
	parent *quicTracer

	// packets are the headers of the long-header packets sent on the
	// connection.
	packets []logging.Header

	// id is the original destination connection ID of the connection.
	id logging.ConnectionID
}
// SentLongHeaderPacket records the header of a sent long-header packet.  It is
// used as the SentLongHeaderPacket callback of the logging.ConnectionTracer
// built by quicTracer.TracerForConnection.
func (q *quicConnTracer) SentLongHeaderPacket(
	hdr *logging.ExtendedHeader,
	_ logging.ByteCount,
	_ logging.ECN,
	_ *logging.AckFrame,
	_ []logging.Frame,
) {
	mu := &q.parent.mu
	mu.Lock()
	defer mu.Unlock()

	q.packets = append(q.packets, hdr.Header)
}
0707010000009A000081A4000000000000000000000001679A649F00001BA3000000000000000000000000000000000000002000000000dnsproxy-0.75.0/upstream/dot.gopackage upstream
import (
"context"
"crypto/tls"
"fmt"
"io"
"log/slog"
"net"
"net/url"
"os"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
)
// dialTimeout bounds establishing a TLS connection.  tlsDial sets it as the
// deadline before the handshake, and it is renewed for every connection taken
// from the pool.
//
// TODO(ameshkov): use bootstrap timeout instead.
const dialTimeout time.Duration = 10 * time.Second
// dnsOverTLS implements the [Upstream] interface for the DNS-over-TLS protocol.
type dnsOverTLS struct {
	// addr is the DNS-over-TLS server URL.
	addr *url.URL

	// getDialer either returns an initialized dial handler or creates a
	// new one.
	getDialer DialerInitializer

	// tlsConf is the configuration of TLS.  A clone of it is used for every
	// newly dialed connection.
	tlsConf *tls.Config

	// connsMu protects conns.
	connsMu *sync.Mutex

	// logger is used for exchange logging.  It is never nil.
	logger *slog.Logger

	// conns stores the connections ready for reuse.  Don't use [sync.Pool]
	// here, since there is no need to deallocate these connections.  putBack
	// appends to the end of it and conn takes from the end.
	//
	// TODO(e.burkov, ameshkov): Currently connections just stored in FILO
	// order, which eventually makes most of them unusable due to timeouts.
	// This leads to weak performance for all exchanges coming across such
	// connections.
	conns []net.Conn
}
// newDoT returns the DNS-over-TLS Upstream for addr configured by opts.
// addPort is applied to addr with defaultPortDoT first.  A finalizer that
// closes the upstream, and thus its pooled connections, is attached to the
// returned value.
func newDoT(addr *url.URL, opts *Options) (ups Upstream, err error) {
	addPort(addr, defaultPortDoT)

	getDialer := newDialerInitializer(addr, opts)

	tlsConf := &tls.Config{
		ServerName:   addr.Hostname(),
		RootCAs:      opts.RootCAs,
		CipherSuites: opts.CipherSuites,
		// Use the default capacity for the LRU cache.  It may be useful to
		// store several caches since the user may be routed to different
		// servers in case there's load balancing on the server-side.
		ClientSessionCache: tls.NewLRUClientSessionCache(0),
		MinVersion:         tls.VersionTLS12,
		// #nosec G402 -- TLS certificate verification could be disabled by
		// configuration.
		InsecureSkipVerify:    opts.InsecureSkipVerify,
		VerifyPeerCertificate: opts.VerifyServerCertificate,
		VerifyConnection:      opts.VerifyConnection,
	}

	tlsUps := &dnsOverTLS{
		addr:      addr,
		getDialer: getDialer,
		tlsConf:   tlsConf,
		connsMu:   &sync.Mutex{},
		logger:    opts.Logger,
	}
	runtime.SetFinalizer(tlsUps, (*dnsOverTLS).Close)

	return tlsUps, nil
}
// type check
var _ Upstream = (*dnsOverTLS)(nil)

// Address implements the [Upstream] interface for *dnsOverTLS.  It returns the
// string form of the upstream's URL.
func (p *dnsOverTLS) Address() (addr string) {
	return p.addr.String()
}
// Exchange implements the [Upstream] interface for *dnsOverTLS.
//
// A pooled connection is used when available.  If the exchange over it fails,
// that connection is closed and the query is retried once over a freshly
// dialed one.  A connection that succeeded is put back into the pool.
func (p *dnsOverTLS) Exchange(req *dns.Msg) (reply *dns.Msg, err error) {
	h, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
	}

	conn, err := p.conn(h)
	if err != nil {
		return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
	}

	reply, err = p.exchangeWithConn(conn, req)
	if err == nil {
		p.putBack(conn)

		return reply, nil
	}

	// The pooled connection might have been closed already, see
	// https://github.com/AdguardTeam/dnsproxy/issues/3.  The next connection
	// from the pool may also be malformed, so dial a new one instead.
	err = errors.WithDeferred(err, conn.Close())
	p.logger.Debug("dot got bad conn from pool", "addr", p.addr, slogutil.KeyError, err)

	conn, err = tlsDial(h, p.tlsConf.Clone())
	if err != nil {
		return nil, fmt.Errorf(
			"dialing %s: connecting to %s: %w",
			p.addr,
			p.tlsConf.ServerName,
			err,
		)
	}

	reply, err = p.exchangeWithConn(conn, req)
	if err != nil {
		return reply, errors.WithDeferred(err, conn.Close())
	}

	p.putBack(conn)

	return reply, nil
}
// Close implements the [Upstream] interface for *dnsOverTLS.  It removes the
// finalizer, closes every pooled connection, and empties the pool.  Close
// errors that isCriticalTCP considers expected are ignored; the rest are
// joined.
func (p *dnsOverTLS) Close() (err error) {
	runtime.SetFinalizer(p, nil)

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	var closeErrs []error
	for _, conn := range p.conns {
		closeErr := conn.Close()
		if closeErr != nil && isCriticalTCP(closeErr) {
			closeErrs = append(closeErrs, closeErr)
		}
	}

	// Drop the closed connections, so that they are neither taken from the
	// pool later nor closed again by a repeated Close.
	p.conns = nil

	return errors.Join(closeErrs...)
}
// conn returns the most recently pooled connection if there is any and its
// deadline can be renewed.  Otherwise it dials a new one.  The dialing happens
// outside of connsMu.
func (p *dnsOverTLS) conn(h bootstrap.DialHandler) (conn net.Conn, err error) {
	// Dial a new connection outside the lock, if needed.  Deferred functions
	// run in reverse order, so this one runs after connsMu is unlocked.
	defer func() {
		if conn != nil {
			return
		}

		// Don't assign the result of tlsDial to conn directly: on failure it
		// returns a nil *tls.Conn, which would make conn a non-nil interface
		// holding a nil pointer.
		tlsConn, dialErr := tlsDial(h, p.tlsConf.Clone())
		if dialErr != nil {
			err = errors.Annotate(dialErr, "connecting to %s: %w", p.tlsConf.ServerName)

			return
		}

		conn = tlsConn
	}()

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	l := len(p.conns)
	if l == 0 {
		return nil, nil
	}

	conn = p.conns[l-1]
	// Clear the slot, so that the backing array doesn't keep the connection
	// alive after it leaves the pool.
	p.conns[l-1] = nil
	p.conns = p.conns[:l-1]

	err = conn.SetDeadline(time.Now().Add(dialTimeout))
	if err != nil {
		p.logger.Debug("dot upstream setting deadline to conn from pool", slogutil.KeyError, err)

		// If the deadline can't be updated, the connection was most likely
		// closed already.  Close it anyway so it isn't leaked, and let the
		// deferred function dial a new one.
		closeErr := conn.Close()
		if closeErr != nil {
			p.logger.Debug("dot upstream closing conn from pool", slogutil.KeyError, closeErr)
		}

		return nil, nil
	}

	p.logger.Debug("dot upstream using existing conn", "raddr", conn.RemoteAddr())

	return conn, nil
}
// putBack returns conn to the end of the pool, making it the first one conn
// takes for the next exchange.
func (p *dnsOverTLS) putBack(conn net.Conn) {
	mu := p.connsMu
	mu.Lock()
	defer mu.Unlock()

	p.conns = append(p.conns, conn)
}
// exchangeWithConn writes req to conn and reads the reply.  A reply whose ID
// doesn't match the request's is returned along with dns.ErrId.
func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, req *dns.Msg) (reply *dns.Msg, err error) {
	upsAddr := p.Address()
	logBegin(p.logger, upsAddr, networkTCP, req)
	defer func() { logFinish(p.logger, upsAddr, networkTCP, err) }()

	dnsConn := &dns.Conn{Conn: conn}

	if err = dnsConn.WriteMsg(req); err != nil {
		return nil, fmt.Errorf("sending request to %s: %w", upsAddr, err)
	}

	reply, err = dnsConn.ReadMsg()
	switch {
	case err != nil:
		return nil, fmt.Errorf("reading response from %s: %w", upsAddr, err)
	case reply.Id != req.Id:
		return reply, dns.ErrId
	default:
		return reply, nil
	}
}
// tlsDial is basically the same as tls.DialWithDialer, but it uses dialContext
// to get the underlying connection.  The connection gets a deadline of
// dialTimeout before the handshake.  If the handshake fails, the connection
// is closed.
func tlsDial(dialContext bootstrap.DialHandler, conf *tls.Config) (c *tls.Conn, err error) {
	// The address argument is left empty: the dial handler uses the
	// bootstrapped address instead of what's passed to it.
	rawConn, err := dialContext(context.Background(), networkTCP, "")
	if err != nil {
		return nil, err
	}

	c = tls.Client(rawConn, conf)

	// The deadline is set after the TCP connection is established.  It bounds
	// the TLS handshake and any I/O on the connection until it's renewed.
	if err = c.SetDeadline(time.Now().Add(dialTimeout)); err != nil {
		// Must not happen in normal circumstances.
		panic(fmt.Errorf("dnsproxy: tls dial: setting deadline: %w", err))
	}

	if err = c.Handshake(); err != nil {
		return nil, errors.WithDeferred(err, c.Close())
	}

	return c, nil
}
// isCriticalTCP reports whether err is unexpected when closing a TCP
// connection.  Expected errors, for which it returns false, are:
//   - timeouts;
//   - io.EOF;
//   - net.ErrClosed;
//   - os.ErrDeadlineExceeded;
//   - anything isConnBroken recognizes.
func isCriticalTCP(err error) (ok bool) {
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return false
	}

	expected := errors.Is(err, io.EOF) ||
		errors.Is(err, net.ErrClosed) ||
		errors.Is(err, os.ErrDeadlineExceeded) ||
		isConnBroken(err)

	return !expected
}
0707010000009B000081A4000000000000000000000001679A649F00001E54000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/upstream/dot_internal_test.gopackage upstream
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"net/url"
"sync"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestUpstream_dnsOverTLS checks that the DoT upstream answers a series of
// sequential queries.
func TestUpstream_dnsOverTLS(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		pt := testutil.PanicT{}
		require.NoError(pt, w.WriteMsg(respondToTestMessage(req)))
	})

	// Create a DoT upstream that we'll be testing.
	addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	for i := 0; i < 10; i++ {
		checkUpstream(t, u, addr)
	}
}
// TestUpstream_dnsOverTLS_race checks that the DoT upstream can be used from
// several goroutines at once.
func TestUpstream_dnsOverTLS_race(t *testing.T) {
	const count = 10

	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		pt := testutil.PanicT{}
		require.NoError(pt, w.WriteMsg(respondToTestMessage(req)))
	})

	// Creating a DoT upstream that we will be testing.
	addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Use this upstream from multiple goroutines in parallel.
	wg := &sync.WaitGroup{}
	wg.Add(count)
	for range count {
		go func() {
			defer wg.Done()

			pt := testutil.PanicT{}
			req := createTestMessage()
			resp, uErr := u.Exchange(req)
			require.NoError(pt, uErr)
			requireResponse(pt, req, resp)
		}()
	}

	wg.Wait()
}
// TestUpstream_dnsOverTLS_poolReconnect checks that the DoT upstream replaces a
// broken pooled connection with a new one, keeping the size of the pool, and
// that the new connection resumes the TLS session.
func TestUpstream_dnsOverTLS_poolReconnect(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})

	// This var is used to store the last connection state in order to check
	// if session resumption works as expected.
	var lastState tls.ConnectionState

	// Init the upstream to the test DoT server that also keeps track of the
	// session resumptions.
	addr := (&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String()
	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
		VerifyConnection: func(state tls.ConnectionState) error {
			lastState = state

			return nil
		},
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	p := testutil.RequireTypeAssert[*dnsOverTLS](t, u)

	// Send the first test message.
	req := createTestMessage()
	reply, err := u.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, reply)

	// Now let's close the pooled connection.
	require.Len(t, p.conns, 1)
	conn := p.conns[0]
	require.NoError(t, conn.Close())

	// Send the second test message.
	req = createTestMessage()
	reply, err = u.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, reply)

	// Now assert that the number of connections in the pool is not changed.
	require.Len(t, p.conns, 1)
	assert.NotSame(t, conn, p.conns[0])

	// Check that the session was resumed on the last attempt.
	assert.True(t, lastState.DidResume)
}
// TestUpstream_dnsOverTLS_poolDeadline checks how the DoT upstream reuses a
// pooled connection depending on the connection's deadline.
func TestUpstream_dnsOverTLS_poolDeadline(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})

	// The DoT upstream under test.
	u, err := AddressToUpstream((&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String(), &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// The first exchange puts a connection into the pool.
	req := createTestMessage()
	resp, err := u.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	dot := testutil.RequireTypeAssert[*dnsOverTLS](t, u)

	// Take the pooled connection and exchange over it directly.
	require.Len(t, dot.conns, 1)
	pooled := dot.conns[0]

	dial, err := dot.getDialer()
	require.NoError(t, err)

	got, err := dot.conn(dial)
	require.NoError(t, err)
	require.Same(t, got, pooled)

	resp, err = dot.exchangeWithConn(pooled, req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	// Push the deadline far into the future and return the connection.
	require.NoError(t, pooled.SetDeadline(time.Now().Add(10*time.Hour)))
	dot.putBack(pooled)

	// The same connection is taken from the pool again and still works.
	require.Len(t, dot.conns, 1)
	pooled = dot.conns[0]

	got, err = dot.conn(dial)
	require.NoError(t, err)
	require.Same(t, got, pooled)

	resp, err = dot.exchangeWithConn(got, req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	// A connection whose deadline is in the past can't be used anymore.
	require.NoError(t, got.SetDeadline(time.Now().Add(-10*time.Hour)))

	resp, err = dot.exchangeWithConn(got, req)
	require.Error(t, err)
	require.Nil(t, resp)
}
// testDoTServer is a test DNS-over-TLS server that can be used in unit-tests.
type testDoTServer struct {
	// srv is the *dns.Server instance that listens for DoT requests.
	srv *dns.Server

	// tlsConfig is the TLS configuration that is used for this server.
	tlsConfig *tls.Config

	// rootCAs is the pool with root certificates used by the test server.
	// NOTE(review): presumably meant for clients that verify the server
	// certificate — confirm with the callers.
	rootCAs *x509.CertPool

	// port is the TCP port on the loopback interface the server listens on.
	port int
}

// type check
var _ io.Closer = (*testDoTServer)(nil)
// startDoTServer starts a *testDoTServer on a random port of 127.0.0.1 and
// registers its shutdown as a cleanup of tb.
//
// TODO(e.burkov): Also return address?
func startDoTServer(tb testing.TB, handler dns.HandlerFunc) (s *testDoTServer) {
	tb.Helper()

	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(tb, err)

	conf, roots := createServerTLSConfig(tb, "127.0.0.1")

	s = &testDoTServer{
		srv: &dns.Server{
			Listener:  tls.NewListener(l, conf),
			TLSConfig: conf,
			Net:       "tls",
			Handler:   handler,
		},
		tlsConfig: conf,
		rootCAs:   roots,
		port:      l.Addr().(*net.TCPAddr).Port,
	}

	go func() {
		require.NoError(testutil.PanicT{}, s.srv.ActivateAndServe())
	}()

	testutil.CleanupAndRequireSuccess(tb, s.Close)

	return s
}
// Close implements the [io.Closer] interface for *testDoTServer by shutting
// down the underlying DNS server.
func (s *testDoTServer) Close() (err error) {
	err = s.srv.Shutdown()

	return err
}
// BenchmarkDoTUpstream measures parallel exchanges through a single DoT
// upstream connected to a local test server.
func BenchmarkDoTUpstream(b *testing.B) {
	srv := startDoTServer(b, func(w dns.ResponseWriter, m *dns.Msg) {
		err := w.WriteMsg(respondToTestMessage(m))
		require.NoError(testutil.PanicT{}, err)
	})

	addr := (&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String()
	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
	})
	require.NoError(b, err)
	testutil.CleanupAndRequireSuccess(b, u.Close)

	// Produce the requests in a separate goroutine.  It must be stopped when
	// the benchmark is over, since otherwise it blocks forever on the full
	// channel and leaks.  The cleanups run in reverse order, so it's stopped
	// before the upstream and the server are closed.
	reqChan := make(chan *dns.Msg, 64)
	done := make(chan struct{})
	b.Cleanup(func() { close(done) })

	go func() {
		for {
			select {
			case reqChan <- createTestMessage():
				// Go on.
			case <-done:
				return
			}
		}
	}()

	// Wait for channel to fill.
	require.Eventually(b, func() bool {
		return len(reqChan) == cap(reqChan)
	}, time.Second, time.Millisecond)

	b.Run("exchange_p", func(b *testing.B) {
		b.ResetTimer()
		b.ReportAllocs()

		b.RunParallel(func(p *testing.PB) {
			for p.Next() {
				_, _ = u.Exchange(<-reqChan)
			}
		})
	})
}
0707010000009C000081A4000000000000000000000001679A649F00000152000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/dot_unix.go//go:build darwin || freebsd || linux || openbsd || netbsd
package upstream
import (
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/unix"
)
// isConnBroken reports whether err signals a broken connection, which on this
// platform means either a broken pipe or a timed-out connection.
func isConnBroken(err error) (ok bool) {
	switch {
	case errors.Is(err, unix.EPIPE), errors.Is(err, unix.ETIMEDOUT):
		return true
	default:
		return false
	}
}
0707010000009D000081A4000000000000000000000001679A649F00000141000000000000000000000000000000000000002800000000dnsproxy-0.75.0/upstream/dot_windows.go//go:build windows
package upstream
import (
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/windows"
)
// isConnBroken reports whether err signals a broken connection, which on
// Windows means a connection aborted or reset by the socket layer.
func isConnBroken(err error) (ok bool) {
	switch {
	case errors.Is(err, windows.WSAECONNABORTED), errors.Is(err, windows.WSAECONNRESET):
		return true
	default:
		return false
	}
}
0707010000009E000081A4000000000000000000000001679A649F00000A52000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/upstream/hostsresolver.gopackage upstream
import (
"context"
"fmt"
"io/fs"
"log/slog"
"net/netip"
"slices"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
)
// HostsResolver is a [Resolver] that looks into system hosts files, see
// [hostsfile].
type HostsResolver struct {
	// strg contains all the hosts file data needed for lookups.  Lookups by
	// name are performed with its ByName method.
	strg hostsfile.Storage
}
// NewHostsResolver returns a *HostsResolver that answers lookups with the data
// from hosts.
func NewHostsResolver(hosts hostsfile.Storage) (hr *HostsResolver) {
	hr = &HostsResolver{strg: hosts}

	return hr
}
// NewDefaultHostsResolver returns a resolver based on the system hosts files at
// the paths from [hostsfile.DefaultHostsPaths], read from rootFSys.  A file
// that doesn't exist is skipped with a debug log record.  If l is nil,
// [slog.Default] is used.
func NewDefaultHostsResolver(rootFSys fs.FS, l *slog.Logger) (hr *HostsResolver, err error) {
	if l == nil {
		l = slog.Default()
	}

	paths, err := hostsfile.DefaultHostsPaths()
	if err != nil {
		return nil, fmt.Errorf("getting default hosts paths: %w", err)
	}

	// No readers are passed to the constructor, so there is nothing it could
	// fail on and the error is always nil.
	strg, _ := hostsfile.NewDefaultStorage()

	for i := range paths {
		// Don't wrap the error since it's already informative enough as is.
		if err = parseHostsFile(rootFSys, strg, paths[i], l); err != nil {
			return nil, err
		}
	}

	return NewHostsResolver(strg), nil
}
// parseHostsFile opens filename within fsys and parses its contents into hosts.
// A missing file isn't an error; it's only logged at the debug level.  The file
// is closed before returning, and a closing error is combined with err.
func parseHostsFile(fsys fs.FS, hosts hostsfile.Set, filename string, l *slog.Logger) (err error) {
	f, err := fsys.Open(filename)
	switch {
	case err == nil:
		// Go on.
	case errors.Is(err, fs.ErrNotExist):
		l.Debug("hosts file does not exist", "filename", filename)

		return nil
	default:
		// Don't wrap the error since it's already informative enough as is.
		return err
	}
	defer func() { err = errors.WithDeferred(err, f.Close()) }()

	return hostsfile.Parse(hosts, f, nil)
}
// type check
var _ Resolver = (*HostsResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *HostsResolver.  network
// must be one of "ip", "ip4", or "ip6".  For the latter two, only the addresses
// of the corresponding family are returned.  The context isn't used.
func (hr *HostsResolver) LookupNetIP(
	_ context.Context,
	network string,
	host string,
) (addrs []netip.Addr, err error) {
	var ipMatches func(netip.Addr) (ok bool)
	switch network {
	case "ip4":
		ipMatches = netip.Addr.Is4
	case "ip6":
		ipMatches = netip.Addr.Is6
	case "ip":
		// Return a copy, presumably so that callers can't mutate the slice
		// owned by the storage.
		return slices.Clone(hr.strg.ByName(host)), nil
	default:
		return nil, fmt.Errorf("unsupported network %q", network)
	}

	for _, addr := range hr.strg.ByName(host) {
		if ipMatches(addr) {
			addrs = append(addrs, addr)
		}
	}

	return addrs, nil
}
0707010000009F000081A4000000000000000000000001679A649F0000093C000000000000000000000000000000000000002F00000000dnsproxy-0.75.0/upstream/hostsresolver_test.gopackage upstream_test
import (
"context"
"net/netip"
"testing"
"testing/fstest"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHostsResolver_LookupNetIP checks lookups of canonical names and aliases
// from a hosts file for every supported network.
func TestHostsResolver_LookupNetIP(t *testing.T) {
	const hostsData = `
1.2.3.4 host1 host2 ipv4.only
::1 host1 host2 ipv6.only
`

	var (
		v4Addr = netip.MustParseAddr("1.2.3.4")
		v6Addr = netip.MustParseAddr("::1")
	)

	paths, err := hostsfile.DefaultHostsPaths()
	require.NoError(t, err)
	require.NotEmpty(t, paths)

	fsys := fstest.MapFS{
		paths[0]: {
			Data: []byte(hostsData),
		},
	}

	hr, err := upstream.NewDefaultHostsResolver(fsys, slogutil.NewDiscardLogger())
	require.NoError(t, err)

	testCases := []struct {
		name      string
		host      string
		net       string
		wantAddrs []netip.Addr
	}{{
		name:      "canonical_any",
		host:      "host1",
		net:       "ip",
		wantAddrs: []netip.Addr{v4Addr, v6Addr},
	}, {
		name:      "canonical_v4",
		host:      "host1",
		net:       "ip4",
		wantAddrs: []netip.Addr{v4Addr},
	}, {
		name:      "canonical_v6",
		host:      "host1",
		net:       "ip6",
		wantAddrs: []netip.Addr{v6Addr},
	}, {
		name:      "alias_any",
		host:      "host2",
		net:       "ip",
		wantAddrs: []netip.Addr{v4Addr, v6Addr},
	}, {
		name:      "alias_v4",
		host:      "host2",
		net:       "ip4",
		wantAddrs: []netip.Addr{v4Addr},
	}, {
		name:      "alias_v6",
		host:      "host2",
		net:       "ip6",
		wantAddrs: []netip.Addr{v6Addr},
	}, {
		name:      "unknown_host",
		host:      "host3",
		net:       "ip",
		wantAddrs: nil,
	}, {
		name:      "family_mismatch_v4",
		host:      "ipv6.only",
		net:       "ip4",
		wantAddrs: nil,
	}, {
		name:      "family_mismatch_v6",
		host:      "ipv4.only",
		net:       "ip6",
		wantAddrs: nil,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Use a subtest-local error so that subtests don't share the
			// parent's variable, which would race if they ran in parallel.
			addrs, lookupErr := hr.LookupNetIP(context.Background(), tc.net, tc.host)
			require.NoError(t, lookupErr)

			assert.Equal(t, tc.wantAddrs, addrs)
		})
	}

	t.Run("unsupported_network", func(t *testing.T) {
		_, lookupErr := hr.LookupNetIP(context.Background(), "ip5", "host1")
		testutil.AssertErrorMsg(t, `unsupported network "ip5"`, lookupErr)
	})
}
070701000000A0000081A4000000000000000000000001679A649F000010C9000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/parallel.gopackage upstream
import (
"fmt"
"slices"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
)
// ErrNoUpstreams is returned from the methods that expect at least a single
// upstream to work with when no upstreams are specified.
const ErrNoUpstreams errors.Error = "no upstream specified"

// ErrNoReply is returned from [ExchangeAll] when no upstreams replied.
const ErrNoReply errors.Error = "no reply"
// ExchangeParallel sends req to all of ups concurrently and returns the first
// successful response along with the upstream that produced it.  It returns
// [ErrNoUpstreams] if ups is empty and an error if all upstreams failed to
// exchange the request.
func ExchangeParallel(ups []Upstream, req *dns.Msg) (reply *dns.Msg, resolved Upstream, err error) {
	upsNum := len(ups)
	switch upsNum {
	case 0:
		return nil, nil, ErrNoUpstreams
	case 1:
		return exchangeSingle(ups[0], req)
	default:
		// Go on.
	}

	// The channel has room for every result, so the goroutines finishing after
	// the first successful response never block.
	resCh := make(chan any, upsNum)
	for _, u := range ups {
		// Use a copy to prevent data races, as [dns.Client] can modify the DNS
		// request during the exchange.
		//
		// TODO(s.chzhen): Consider using buffer pool.
		copyReq := req.Copy()
		go exchangeAsync(u, copyReq, resCh)
	}

	var errs []error
	for range ups {
		var r *ExchangeAllResult
		r, err = receiveAsyncResult(resCh)
		if err == nil {
			return r.Resp, r.Upstream, nil
		}

		if !errors.Is(err, ErrNoReply) {
			errs = append(errs, err)
		}
	}

	// TODO(e.burkov): Probably it's better to return the joined error from
	// each upstream that returned no response, and get rid of multiple
	// [errors.Is] calls.  This will change the behavior though.
	if len(errs) == 0 {
		return nil, nil, errors.Error("none of upstream servers responded")
	}

	return nil, nil, errors.Join(errs...)
}
// exchangeSingle exchanges req with ups and returns the response along with ups
// itself.  A nil response without an error is reported as [ErrNoReply], which
// is consistent with [ExchangeAll] for a single upstream.  The multi-upstream
// path of [ExchangeParallel] doesn't consider such a response successful
// either.  This way callers never get both a nil response and a nil error.
func exchangeSingle(
	ups Upstream,
	req *dns.Msg,
) (resp *dns.Msg, resolved Upstream, err error) {
	resp, err = ups.Exchange(req)
	if err != nil {
		return nil, nil, err
	}

	if resp == nil {
		return nil, nil, ErrNoReply
	}

	return resp, ups, nil
}
// ExchangeAllResult is the successful result of [ExchangeAll] for a single
// upstream.
type ExchangeAllResult struct {
	// Resp is the response DNS request resolved into.  It's never nil in the
	// results returned by [ExchangeAll].
	Resp *dns.Msg

	// Upstream is the upstream that successfully resolved the request.
	Upstream Upstream
}
// ExchangeAll sends req to all of ups concurrently and returns the successful
// results in the order they were received.  It returns [ErrNoUpstreams] if ups
// is empty.  It returns an error only if all upstreams failed to exchange the
// request, in which case res is nil.  A nil response without an error counts
// as a failure with [ErrNoReply].
func ExchangeAll(ups []Upstream, req *dns.Msg) (res []ExchangeAllResult, err error) {
	upsNum := len(ups)
	switch upsNum {
	case 0:
		return nil, ErrNoUpstreams
	case 1:
		var reply *dns.Msg
		reply, err = ups[0].Exchange(req)
		if err != nil {
			return nil, err
		} else if reply == nil {
			return nil, ErrNoReply
		}

		return []ExchangeAllResult{{Upstream: ups[0], Resp: reply}}, nil
	default:
		// Go on.
	}

	resCh := make(chan any, upsNum)

	// Start exchanging concurrently.
	for _, u := range ups {
		// Use a copy to prevent data races, as [dns.Client] can modify the DNS
		// request during the exchange.
		//
		// TODO(s.chzhen): Consider using buffer pool.
		copyReq := req.Copy()
		go exchangeAsync(u, copyReq, resCh)
	}

	// Wait for all exchanges to finish.
	res = make([]ExchangeAllResult, 0, upsNum)
	var errs []error
	for range ups {
		r, rErr := receiveAsyncResult(resCh)
		if rErr != nil {
			errs = append(errs, rErr)

			continue
		}

		res = append(res, *r)
	}

	if len(errs) == upsNum {
		// Don't return the empty results along with the error.
		return nil, fmt.Errorf("all upstreams failed: %w", errors.Join(errs...))
	}

	return slices.Clip(res), nil
}
// receiveAsyncResult receives a single value from resCh, as sent by
// [exchangeAsync], and converts it into either a non-nil result or an error.  A
// result without a response is reported as [ErrNoReply].
func receiveAsyncResult(resCh chan any) (res *ExchangeAllResult, err error) {
	switch v := (<-resCh).(type) {
	case error:
		return nil, v
	case *ExchangeAllResult:
		if v.Resp == nil {
			return nil, ErrNoReply
		}

		return v, nil
	default:
		return nil, fmt.Errorf("unexpected type %T of result", v)
	}
}
// exchangeAsync exchanges req with u and sends the outcome to resCh.  The
// outcome is either the exchange error or an *ExchangeAllResult.  The response
// in the result isn't checked for nil here.
func exchangeAsync(u Upstream, req *dns.Msg, resCh chan any) {
	reply, err := u.Exchange(req)
	if err != nil {
		resCh <- err

		return
	}

	resCh <- &ExchangeAllResult{Resp: reply, Upstream: u}
}
070701000000A1000081A4000000000000000000000001679A649F00000D44000000000000000000000000000000000000003300000000dnsproxy-0.75.0/upstream/parallel_internal_test.gopackage upstream
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// timeout is the exchange timeout used by the upstreams in the parallel
// exchange tests.
const timeout = 2 * time.Second
// TestExchangeParallel launches several parallel exchanges and checks that the
// response comes from 8.8.8.8:53, the upstream expected to be the only working
// one, within the timeout.
//
// NOTE(review): The test requires access to the public DNS servers, so it may
// fail in an isolated environment.
func TestExchangeParallel(t *testing.T) {
	var upstreams []Upstream
	for _, s := range []string{"1.2.3.4:55", "8.8.8.1", "8.8.8.8:53"} {
		u, err := AddressToUpstream(s, &Options{
			Logger:  slogutil.NewDiscardLogger(),
			Timeout: timeout,
		})
		require.NoError(t, err, "cannot create upstream")

		// Close the upstreams after the test to not leak their resources.
		t.Cleanup(func() { assert.NoError(t, u.Close()) })

		upstreams = append(upstreams, u)
	}

	req := createTestMessage()
	start := time.Now()
	resp, u, err := ExchangeParallel(upstreams, req)
	require.NoError(t, err, "no response from test upstreams")
	require.Equal(t, "8.8.8.8:53", u.Address(), "this upstream can't resolve DNS request")

	requireResponse(t, req, resp)

	elapsed := time.Since(start)
	require.LessOrEqual(t, elapsed, timeout, "exchange took more time than the configured timeout")
}
// TestExchangeParallelEmpty checks that ExchangeParallel fails when every
// upstream returns neither a response nor an error.
func TestExchangeParallelEmpty(t *testing.T) {
	ups := []Upstream{&testUpstream{empty: true}, &testUpstream{empty: true}}

	resp, up, err := ExchangeParallel(ups, createTestMessage())
	require.Error(t, err)

	assert.Nil(t, resp)
	assert.Nil(t, up)
}
// testUpstream represents a mock upstream structure.
type testUpstream struct {
	// addr is a mock A record IP address to be returned.  No answer is added
	// to the response if it's the zero value.
	addr netip.Addr

	// err indicates if a mock error is returned instead of a response.
	err bool

	// empty indicates if a nil response is returned along with a nil error.
	// It takes precedence over err.
	empty bool

	// sleep is a delay before response.
	sleep time.Duration
}

// type check
var _ Upstream = (*testUpstream)(nil)
// Exchange implements the [Upstream] interface for *testUpstream.  It waits for
// u.sleep, if set, and then replies according to the other mock fields.
func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	if u.sleep != 0 {
		time.Sleep(u.sleep)
	}

	switch {
	case u.empty:
		return nil, nil
	case u.err:
		return nil, fmt.Errorf("upstream error")
	}

	resp = new(dns.Msg)
	resp.SetReply(req)

	if u.addr.IsValid() {
		resp.Answer = append(resp.Answer, &dns.A{A: u.addr.AsSlice()})
	}

	return resp, nil
}
// Address implements the [Upstream] interface for *testUpstream.  The mock has
// no real address, so it's always empty.
func (u *testUpstream) Address() (addr string) {
	return addr
}

// Close implements the [Upstream] interface for *testUpstream.  There is
// nothing to release, so it always succeeds.
func (u *testUpstream) Close() (err error) {
	return err
}
// TestExchangeAll checks that ExchangeAll skips the failed upstreams and returns
// the successful results in the order of their arrival.
func TestExchangeAll(t *testing.T) {
	delayedAnsAddr := netip.MustParseAddr("1.1.1.1")
	ansAddr := netip.MustParseAddr("3.3.3.3")

	ups := []Upstream{
		&testUpstream{addr: delayedAnsAddr, sleep: 100 * time.Millisecond},
		&testUpstream{err: true},
		&testUpstream{addr: ansAddr},
	}

	res, err := ExchangeAll(ups, createHostTestMessage("test.org"))
	require.NoError(t, err)
	require.Len(t, res, 2)

	// The fast upstream answers first, the delayed one second.
	for i, want := range []netip.Addr{ansAddr, delayedAnsAddr} {
		resp := res[i].Resp
		require.NotNil(t, resp)
		require.NotEmpty(t, resp.Answer)
		require.IsType(t, new(dns.A), resp.Answer[0])

		ip := resp.Answer[0].(*dns.A).A
		assert.Equal(t, want.AsSlice(), []byte(ip))
	}
}
070701000000A2000081A4000000000000000000000001679A649F000015FA000000000000000000000000000000000000002200000000dnsproxy-0.75.0/upstream/plain.gopackage upstream
import (
"context"
"fmt"
"io"
"log/slog"
"net"
"net/url"
"strings"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
)
// network is a semantic alias for the network name passed to the dialing
// functions.  Its value is either [networkUDP] or [networkTCP], and it may also
// serve as the URL scheme of plain upstreams.
type network = string

// Networks supported by plain DNS upstreams.
const (
	// networkTCP is the TCP network.
	networkTCP network = "tcp"

	// networkUDP is the UDP network.
	networkUDP network = "udp"
)
// plainDNS implements the [Upstream] interface for the regular DNS protocol.
type plainDNS struct {
	// addr is the DNS server URL.  Scheme is always "udp" or "tcp".
	addr *url.URL

	// logger is used for exchange logging.  It is never nil.
	logger *slog.Logger

	// getDialer either returns an initialized dial handler or creates a new
	// one.
	getDialer DialerInitializer

	// net is the network of the connections.  newPlain sets it from the scheme
	// of addr, and Exchange uses it for the first attempt.
	net network

	// timeout is the timeout for DNS requests.  It's passed to the
	// [dns.Client] of every exchange.
	timeout time.Duration
}
// newPlain returns the plain DNS Upstream for addr, whose scheme must be either
// "udp" or "tcp".  Before use, addr is passed to addPort with the default port
// for plain DNS.
func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) {
	if s := addr.Scheme; s != networkUDP && s != networkTCP {
		return nil, fmt.Errorf("unsupported url scheme: %s", s)
	}

	addPort(addr, defaultPortPlain)

	u = &plainDNS{
		addr:      addr,
		logger:    opts.Logger,
		getDialer: newDialerInitializer(addr, opts),
		net:       addr.Scheme,
		timeout:   opts.Timeout,
	}

	return u, nil
}
// type check
var _ Upstream = (*plainDNS)(nil)

// Address implements the [Upstream] interface for *plainDNS.  It returns the
// host of the URL for UDP and the whole URL for TCP.  It panics on any other
// network.
func (p *plainDNS) Address() string {
	if p.net == networkUDP {
		return p.addr.Host
	}

	if p.net != networkTCP {
		panic(fmt.Sprintf("unexpected network: %s", p.net))
	}

	return p.addr.String()
}
// dialExchange performs a DNS exchange with the specified dial handler.
// network must be either [networkUDP] or [networkTCP].
//
// If the first exchange fails with an error for which isExpectedConnErr
// returns true, the exchange is retried once over a freshly dialed
// connection.  Every dialed connection is closed when the method returns, and
// the closing errors are combined with err.  If the exchange fails, resp is
// returned as received together with the wrapped error.  Otherwise resp is
// checked with validatePlainResponse, and resp is returned along with the
// result of that check.
func (p *plainDNS) dialExchange(
	network network,
	dial bootstrap.DialHandler,
	req *dns.Msg,
) (resp *dns.Msg, err error) {
	addr := p.Address()
	client := &dns.Client{Timeout: p.timeout}

	conn := &dns.Conn{}
	if network == networkUDP {
		// Start with the minimal DNS message size for UDP.  NOTE(review): the
		// dns package may raise it from the request's EDNS0 OPT record —
		// confirm.
		conn.UDPSize = dns.MinMsgSize
	}

	logBegin(p.logger, addr, network, req)
	defer func() { logFinish(p.logger, addr, network, err) }()

	ctx := context.Background()
	// NOTE(review): the address argument is empty, presumably because the
	// dial handler already knows the upstream's address — confirm.
	conn.Conn, err = dial(ctx, network, "")
	if err != nil {
		return nil, fmt.Errorf("dialing %s over %s: %w", p.addr.Host, network, err)
	}
	// Pass the connection as an argument so that this deferred call closes the
	// first connection even after conn.Conn is replaced below.
	defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn)

	resp, _, err = client.ExchangeWithConn(req, conn)
	if isExpectedConnErr(err) {
		// Retry once over a new connection.
		conn.Conn, err = dial(ctx, network, "")
		if err != nil {
			return nil, fmt.Errorf("dialing %s over %s again: %w", p.addr.Host, network, err)
		}
		defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn)

		resp, _, err = client.ExchangeWithConn(req, conn)
	}

	if err != nil {
		return resp, fmt.Errorf("exchanging with %s over %s: %w", addr, network, err)
	}

	return resp, validatePlainResponse(req, resp)
}
// isExpectedConnErr reports whether err justifies a second attempt to process
// the request.  That is the case when err is non-nil and is either an [io.EOF]
// or a [net.Error] somewhere in its chain.
func isExpectedConnErr(err error) (is bool) {
	if err == nil {
		return false
	}

	if errors.Is(err, io.EOF) {
		return true
	}

	var netErr net.Error

	return errors.As(err, &netErr)
}
// Exchange implements the [Upstream] interface for *plainDNS.  If the UDP
// response is truncated or its question section doesn't match the request, the
// request is repeated over TCP.
func (p *plainDNS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	dial, err := p.getDialer()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return nil, err
	}

	addr := p.Address()

	resp, err = p.dialExchange(p.net, dial, req)

	// There is nothing to fall back from when the network is already TCP, and
	// a missing response likely means an error with the upstream.
	if p.net != networkUDP || resp == nil {
		return resp, err
	}

	switch {
	case errors.Is(err, errQuestion):
		// The upstream responds with malformed messages, so try TCP.
		p.logger.Debug(
			"plain response is malformed, using tcp",
			"addr", addr,
			slogutil.KeyError, err,
		)
	case resp.Truncated:
		// Fallback to TCP on truncated responses.
		p.logger.Debug(
			"plain response is truncated, using tcp",
			"question", &req.Question[0],
			"addr", addr,
		)
	default:
		// There is either no error or the error isn't related to the received
		// message.
		return resp, err
	}

	return p.dialExchange(networkTCP, dial, req)
}
// Close implements the [Upstream] interface for *plainDNS.  It does nothing,
// since every exchange closes the connections it dials.
func (p *plainDNS) Close() (err error) {
	return err
}
// errQuestion is returned when a message has malformed question section.
const errQuestion errors.Error = "bad question section"

// validatePlainResponse validates resp from an upstream DNS server for
// compliance with req.  Any error returned wraps errQuestion, since it
// essentially validates the question section of resp.
func validatePlainResponse(req, resp *dns.Msg) (err error) {
	if qlen := len(resp.Question); qlen != 1 {
		return fmt.Errorf("%w: only 1 question allowed; got %d", errQuestion, qlen)
	}

	// The response comes from the network, so it must not be able to cause an
	// out-of-range panic when the request has no question to compare with.
	if len(req.Question) == 0 {
		return fmt.Errorf("%w: got a question in response to a request without one", errQuestion)
	}

	reqQ, respQ := req.Question[0], resp.Question[0]

	if reqQ.Qtype != respQ.Qtype {
		return fmt.Errorf("%w: mismatched type %s", errQuestion, dns.Type(respQ.Qtype))
	}

	// Compare the names case-insensitively, just like CoreDNS does.
	if !strings.EqualFold(reqQ.Name, respQ.Name) {
		return fmt.Errorf("%w: mismatched name %q", errQuestion, respQ.Name)
	}

	return nil
}
070701000000A3000081A4000000000000000000000001679A649F0000124B000000000000000000000000000000000000003000000000dnsproxy-0.75.0/upstream/plain_internal_test.gopackage upstream
import (
"fmt"
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestUpstream_plainDNS checks that a plain DNS upstream repeatedly exchanges
// messages with a local test server.
func TestUpstream_plainDNS(t *testing.T) {
	srv := startDNSServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})
	testutil.CleanupAndRequireSuccess(t, srv.Close)

	addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{Logger: slogutil.NewDiscardLogger()})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	for range 10 {
		checkUpstream(t, u, addr)
	}
}
// TestUpstream_plainDNS_badID checks that a plain DNS upstream doesn't accept a
// response with a mismatched ID, which results in a timeout.
func TestUpstream_plainDNS_badID(t *testing.T) {
	req := createTestMessage()

	badResp := respondToTestMessage(req)
	badResp.Id++

	srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(badResp))
	})
	testutil.CleanupAndRequireSuccess(t, srv.Close)

	u, err := AddressToUpstream(fmt.Sprintf("127.0.0.1:%d", srv.port), &Options{
		Logger: slogutil.NewDiscardLogger(),
		// Use a shorter timeout to speed up the test.
		Timeout: 100 * time.Millisecond,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	got, err := u.Exchange(req)

	var netErr net.Error
	require.ErrorAs(t, err, &netErr)
	assert.True(t, netErr.Timeout())
	assert.Nil(t, got)
}
// TestUpstream_plainDNS_fallbackToTCP checks that a plain DNS upstream repeats
// the request over TCP only when the UDP response is truncated or has a
// mismatched question.
func TestUpstream_plainDNS_fallbackToTCP(t *testing.T) {
	req := createTestMessage()
	goodResp := respondToTestMessage(req)

	truncResp := goodResp.Copy()
	truncResp.Truncated = true

	badQNameResp := goodResp.Copy()
	badQNameResp.Question[0].Name = "bad." + req.Question[0].Name

	badQTypeResp := goodResp.Copy()
	badQTypeResp.Question[0].Qtype = dns.TypeCNAME

	testCases := []struct {
		udpResp *dns.Msg
		name    string
		wantUDP int
		wantTCP int
	}{{
		udpResp: goodResp,
		name:    "all_right",
		wantUDP: 1,
		wantTCP: 0,
	}, {
		udpResp: truncResp,
		name:    "truncated_response",
		wantUDP: 1,
		wantTCP: 1,
	}, {
		udpResp: badQNameResp,
		name:    "bad_qname",
		wantUDP: 1,
		wantTCP: 1,
	}, {
		udpResp: badQTypeResp,
		name:    "bad_qtype",
		wantUDP: 1,
		wantTCP: 1,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var udpNum, tcpNum atomic.Uint32

			srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) {
				// TCP requests always get the correct response.
				resp := goodResp
				if w.RemoteAddr().Network() == networkUDP {
					udpNum.Add(1)
					resp = tc.udpResp
				} else {
					tcpNum.Add(1)
				}

				require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
			})
			testutil.CleanupAndRequireSuccess(t, srv.Close)

			u, err := AddressToUpstream(fmt.Sprintf("127.0.0.1:%d", srv.port), &Options{
				Logger: slogutil.NewDiscardLogger(),
				// Use a shorter timeout to speed up the test.
				Timeout: 100 * time.Millisecond,
			})
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			got, err := u.Exchange(req)
			require.NoError(t, err)
			requireResponse(t, req, got)

			assert.Equal(t, tc.wantUDP, int(udpNum.Load()))
			assert.Equal(t, tc.wantTCP, int(tcpNum.Load()))
		})
	}
}
// testDNSServer is a simple DNS server that can be used in unit-tests.  It
// serves both UDP and TCP on the same port.
type testDNSServer struct {
	// udpListener is the UDP socket served by udpSrv.
	udpListener net.PacketConn

	// tcpListener is the TCP listener served by tcpSrv.
	tcpListener net.Listener

	// udpSrv serves DNS over UDP.
	udpSrv *dns.Server

	// tcpSrv serves DNS over TCP.
	tcpSrv *dns.Server

	// port is the port number shared by both listeners.
	port int
}

// type check
var _ io.Closer = (*testDNSServer)(nil)
// startDNSServer starts a test DNS server serving both UDP and TCP on the same
// random port of 127.0.0.1.  The caller is responsible for closing the returned
// server.
func startDNSServer(t *testing.T, handler dns.HandlerFunc) (s *testDNSServer) {
	t.Helper()

	s = &testDNSServer{}

	udpListener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)

	s.port = udpListener.LocalAddr().(*net.UDPAddr).Port
	s.udpListener = udpListener

	s.tcpListener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", s.port))
	if err != nil {
		// Don't leak the UDP socket if the TCP port with the same number is
		// unavailable, e.g. occupied by another process.  The listening error
		// is what fails the test, so the closing error is deliberately
		// ignored.
		_ = udpListener.Close()
	}
	require.NoError(t, err)

	s.udpSrv = &dns.Server{
		PacketConn: s.udpListener,
		Handler:    handler,
	}
	s.tcpSrv = &dns.Server{
		Listener: s.tcpListener,
		Handler:  handler,
	}

	go func() {
		pt := testutil.PanicT{}
		require.NoError(pt, s.udpSrv.ActivateAndServe())
	}()
	go func() {
		pt := testutil.PanicT{}
		require.NoError(pt, s.tcpSrv.ActivateAndServe())
	}()

	return s
}
// Close implements the [io.Closer] interface for *testDNSServer.  It shuts down
// the UDP server first and then the TCP one, and combines their errors.
func (s *testDNSServer) Close() (err error) {
	err = s.udpSrv.Shutdown()

	return errors.WithDeferred(err, s.tcpSrv.Shutdown())
}
070701000000A4000081A4000000000000000000000001679A649F000023D5000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/resolver.gopackage upstream
import (
"context"
"fmt"
"math"
"net/netip"
"net/url"
"slices"
"strings"
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
)
type (
	// Resolver resolves hostnames into IP addresses.  Note that [net.Resolver]
	// from the standard library also implements this interface.
	Resolver = bootstrap.Resolver

	// StaticResolver is a resolver that always responds with its underlying
	// slice of IP addresses.
	StaticResolver = bootstrap.StaticResolver

	// ParallelResolver is a slice of resolvers queried concurrently until the
	// first successful response, unlike [ConsequentResolver], which queries
	// them in order.
	ParallelResolver = bootstrap.ParallelResolver

	// ConsequentResolver is a slice of resolvers queried in order until the
	// first successful non-empty response, unlike [ParallelResolver], which
	// only requires a successful one.
	ConsequentResolver = bootstrap.ConsequentResolver
)
// UpstreamResolver is a wrapper around Upstream that implements the
// [bootstrap.Resolver] interface.
type UpstreamResolver struct {
	// Upstream is used for lookups.  It must not be nil.  It's embedded, so its
	// methods, including Exchange, are promoted to UpstreamResolver.
	Upstream
}
// NewUpstreamResolver creates an upstream from resolverAddress and wraps it into
// an *UpstreamResolver usable as a bootstrap [Resolver].  resolverAddress has
// the same format as in [AddressToUpstream].  opts may be nil.  Only the
// logger, IPv6 preference, timeout, and server certificate verification are
// taken from opts.
//
// If the upstream can't be used as a bootstrap, r is still returned and fully
// usable, while the returned error has the underlying type of
// [NotBootstrapError].  Closing r.Upstream is caller's responsibility.
func NewUpstreamResolver(resolverAddress string, opts *Options) (r *UpstreamResolver, err error) {
	// TODO(ameshkov): Aren't other options needed here?
	upsOpts := &Options{}
	if opts != nil {
		upsOpts.Logger = opts.Logger
		upsOpts.PreferIPv6 = opts.PreferIPv6
		upsOpts.Timeout = opts.Timeout
		upsOpts.VerifyServerCertificate = opts.VerifyServerCertificate
	}

	ups, err := AddressToUpstream(resolverAddress, upsOpts)
	if err != nil {
		return nil, fmt.Errorf("upstream bootstrap: creating upstream: %w", err)
	}

	r = &UpstreamResolver{Upstream: ups}

	return r, validateBootstrap(ups)
}
// NotBootstrapError is returned when a parsed upstream can't be used as a
// bootstrap, for example by [NewUpstreamResolver].  It wraps the actual reason.
type NotBootstrapError struct {
	// err is the reason why the upstream can't be used as a bootstrap.
	err error
}

// type check
var _ error = NotBootstrapError{}

// Error implements the [error] interface for NotBootstrapError.
func (e NotBootstrapError) Error() (msg string) {
	msg = fmt.Sprintf("not a bootstrap: %s", e.err)

	return msg
}
// type check
var _ errors.Wrapper = NotBootstrapError{}

// Unwrap implements the [errors.Wrapper] interface for NotBootstrapError.  It
// returns the reason why the upstream can't be used as a bootstrap.
func (e NotBootstrapError) Unwrap() (reason error) {
	reason = e.err

	return reason
}
// validateBootstrap returns an error if u can't be used as a bootstrap.  That
// is the case when the hostname of its URL isn't an IP address.  DNSCrypt
// upstreams are always accepted, and unknown upstream types are rejected.
func validateBootstrap(u Upstream) (err error) {
	var upsURL *url.URL
	switch ups := u.(type) {
	case *dnsCrypt:
		// NOTE(review): presumably DNSCrypt upstreams carry the server address
		// in their stamps and thus never need resolving — confirm.
		return nil
	case *plainDNS:
		upsURL = ups.addr
	case *dnsOverTLS:
		upsURL = ups.addr
	case *dnsOverHTTPS:
		upsURL = ups.addr
	case *dnsOverQUIC:
		upsURL = ups.addr
	default:
		return fmt.Errorf("unknown upstream type: %T", u)
	}

	// An IP address in place of the hostname means the upstream doesn't need a
	// bootstrap itself.
	if _, err = netip.ParseAddr(upsURL.Hostname()); err != nil {
		return NotBootstrapError{err: err}
	}

	return nil
}
// type check
var _ Resolver = (*UpstreamResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *UpstreamResolver.  It
// doesn't consider the TTL of the DNS records.  host is lower-cased and made
// fully-qualified before the lookup.  An empty host yields neither addresses
// nor an error.
//
// TODO(e.burkov): Investigate why the empty slice is returned instead of nil.
func (r *UpstreamResolver) LookupNetIP(
	ctx context.Context,
	network bootstrap.Network,
	host string,
) (ips []netip.Addr, err error) {
	if host == "" {
		return nil, nil
	}

	res, err := r.lookupNetIP(ctx, network, dns.Fqdn(strings.ToLower(host)))
	if err != nil {
		return []netip.Addr{}, err
	}

	return res.addrs, nil
}
// ipResult is the result of looking up a host: the addresses from the A and/or
// AAAA records of the responses.  It's used to cache the results of lookups.
type ipResult struct {
	// expire is presumably the moment the result stops being valid.  When
	// merging A and AAAA results, lookupNetIP keeps the earliest one.
	expire time.Time

	// addrs are the resolved IP addresses.
	addrs []netip.Addr
}
// lookupNetIP performs a DNS lookup of host and returns the result.  network
// must be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6], or
// [bootstrap.NetworkIP].  host must be in a lower-case FQDN form.  For
// [bootstrap.NetworkIP], both address families are resolved concurrently.
// Their addresses are merged, and the earliest expiration time is kept.
//
// TODO(e.burkov): Use context.
func (r *UpstreamResolver) lookupNetIP(
	_ context.Context,
	network bootstrap.Network,
	host string,
) (result *ipResult, err error) {
	switch network {
	case bootstrap.NetworkIP4, bootstrap.NetworkIP6:
		return r.request(host, network)
	case bootstrap.NetworkIP:
		// Go on.
	default:
		return result, fmt.Errorf("unsupported network %s", network)
	}

	const familiesNum = 2

	resCh := make(chan any, familiesNum)
	go r.resolveAsync(resCh, host, bootstrap.NetworkIP4)
	go r.resolveAsync(resCh, host, bootstrap.NetworkIP6)

	var errs []error
	result = &ipResult{}
	for range familiesNum {
		switch v := (<-resCh).(type) {
		case error:
			errs = append(errs, v)
		case *ipResult:
			if result.expire.IsZero() || v.expire.Before(result.expire) {
				result.expire = v.expire
			}

			result.addrs = append(result.addrs, v.addrs...)
		}
	}

	return result, errors.Join(errs...)
}
// request sends a single query for host and collects every valid address from
// the answer section of the response.  n must be either [bootstrap.NetworkIP4]
// or [bootstrap.NetworkIP6], otherwise request panics.  host must be a
// lower-case FQDN.
//
// The result expires after the smallest TTL among the records that yielded an
// address.  NOTE(review): if no record yields an address, the TTL stays at
// [math.MaxUint32] seconds, so the empty result expires far in the future.
//
// TODO(e.burkov): Consider NS and Extra sections when setting TTL.  Check out
// what RFCs say about it.
func (r *UpstreamResolver) request(host string, n bootstrap.Network) (res *ipResult, err error) {
	var qtype uint16
	switch n {
	case bootstrap.NetworkIP4:
		qtype = dns.TypeA
	case bootstrap.NetworkIP6:
		qtype = dns.TypeAAAA
	default:
		panic(fmt.Sprintf("unsupported network %q", n))
	}

	// SetQuestion assigns a random ID, sets the RD bit, and adds a single
	// question of class IN.
	req := (&dns.Msg{}).SetQuestion(host, qtype)

	// As per [Upstream.Exchange] documentation, the response is always returned
	// if no error occurred.
	resp, err := r.Exchange(req)
	if err != nil {
		return res, err
	}

	now := time.Now()
	addrs := make([]netip.Addr, 0, len(resp.Answer))
	minTTL := uint32(math.MaxUint32)
	for _, rr := range resp.Answer {
		if ip := proxyutil.IPFromRR(rr); ip.IsValid() {
			addrs = append(addrs, ip)
			minTTL = min(minTTL, rr.Header().Ttl)
		}
	}

	return &ipResult{
		expire: now.Add(time.Duration(minTTL) * time.Second),
		addrs:  addrs,
	}, nil
}
// resolveAsync queries host for the given network and sends exactly one value
// to resCh: either the error or the *ipResult.  It's meant to be run in its own
// goroutine.
func (r *UpstreamResolver) resolveAsync(resCh chan<- any, host, network string) {
	res, err := r.request(host, network)
	if err != nil {
		resCh <- err

		return
	}

	resCh <- res
}
// CachingResolver is a [Resolver] that keeps the results of lookups until they
// expire.  It must be created with [NewCachingResolver], since its zero value
// has neither a mutex nor a cache.
type CachingResolver struct {
	// resolver is the underlying resolver used for lookups.
	resolver *UpstreamResolver

	// mu protects cache and its elements.
	mu *sync.RWMutex

	// cache maps resolved hostnames to their cached results.
	//
	// TODO(e.burkov): Use expiration cache.
	cache map[string]*ipResult
}

// NewCachingResolver returns a new caching resolver with an empty cache that
// performs its lookups using r.
func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) {
	cr = &CachingResolver{
		resolver: r,
		mu:       new(sync.RWMutex),
		cache:    make(map[string]*ipResult),
	}

	return cr
}
// type check
var _ Resolver = (*CachingResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *CachingResolver.  host
// is normalized to a lower-case FQDN.  Unexpired non-empty cached results are
// returned as copies.  Otherwise, the lookup is performed and its result is
// cached.
//
// NOTE(review): the cache is keyed by host only, so lookups of the same host
// for different networks share a single cache entry.
//
// TODO(e.burkov): It may appear that several concurrent lookup results rewrite
// each other in the cache.
func (r *CachingResolver) LookupNetIP(
	ctx context.Context,
	network bootstrap.Network,
	host string,
) (addrs []netip.Addr, err error) {
	now := time.Now()
	host = dns.Fqdn(strings.ToLower(host))

	// Treat an empty cached result as a miss.  A response without any valid
	// addresses gets an expiration time of math.MaxUint32 seconds in the
	// future.  Serving it from the cache would pin the empty result for the
	// host practically forever.
	addrs = r.findCached(host, now)
	if len(addrs) > 0 {
		return slices.Clone(addrs), nil
	}

	res, err := r.resolver.lookupNetIP(ctx, network, host)
	if err != nil {
		return []netip.Addr{}, err
	}

	r.setCached(host, res)

	return slices.Clone(res.addrs), nil
}
// findCached returns the addresses cached for host, provided that the entry
// hasn't expired by now.  It returns nil if host isn't cached or its entry has
// expired.  It's safe for concurrent use.
func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.Addr) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if res, ok := r.cache[host]; ok && !res.expire.Before(now) {
		addrs = res.addrs
	}

	return addrs
}

// setCached stores res as the cached result for host, replacing any previous
// entry.  It's safe for concurrent use.
func (r *CachingResolver) setCached(host string, res *ipResult) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.cache[host] = res
}
070701000000A5000081A4000000000000000000000001679A649F000009CF000000000000000000000000000000000000003300000000dnsproxy-0.75.0/upstream/resolver_internal_test.gopackage upstream
import (
"context"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCachingResolver_staleness(t *testing.T) {
	const (
		smallTTL = 10 * time.Second
		largeTTL = 1000 * time.Second
		fqdn     = "test.fully.qualified.name."
	)

	ip4 := netip.MustParseAddr("1.2.3.4")
	ip6 := netip.MustParseAddr("2001:db8::1")

	// The A record expires much earlier than the AAAA one, so the merged result
	// is expected to expire together with the A record.
	onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) {
		q := req.Question[0]
		hdr := dns.RR_Header{
			Name:   q.Name,
			Rrtype: q.Qtype,
			Class:  dns.ClassINET,
		}

		var rr dns.RR
		switch q.Qtype {
		case dns.TypeA:
			hdr.Ttl = uint32(smallTTL.Seconds())
			rr = &dns.A{Hdr: hdr, A: ip4.AsSlice()}
		case dns.TypeAAAA:
			hdr.Ttl = uint32(largeTTL.Seconds())
			rr = &dns.AAAA{Hdr: hdr, AAAA: ip6.AsSlice()}
		default:
			require.Contains(testutil.PanicT{}, []uint16{dns.TypeA, dns.TypeAAAA}, q.Qtype)
		}

		resp = (&dns.Msg{}).SetReply(req)
		resp.Answer = append(resp.Answer, rr)

		return resp, nil
	}

	r := NewCachingResolver(&UpstreamResolver{
		Upstream: &dnsproxytest.FakeUpstream{
			OnAddress:  func() (_ string) { panic("not implemented") },
			OnClose:    func() (_ error) { panic("not implemented") },
			OnExchange: onExchange,
		},
	})

	resolved := t.Run("resolve", func(t *testing.T) {
		testCases := []struct {
			name    string
			network bootstrap.Network
			want    []netip.Addr
		}{{
			name:    "ip4",
			network: bootstrap.NetworkIP4,
			want:    []netip.Addr{ip4},
		}, {
			name:    "ip6",
			network: bootstrap.NetworkIP6,
			want:    []netip.Addr{ip6},
		}, {
			name:    "both",
			network: bootstrap.NetworkIP,
			want:    []netip.Addr{ip4, ip6},
		}}

		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				if tc.name != "both" {
					t.Skip(`TODO(e.burkov): Bootstrap now only uses "ip" network, see TODO there.`)
				}

				got, err := r.LookupNetIP(context.Background(), tc.network, fqdn)
				require.NoError(t, err)

				assert.ElementsMatch(t, tc.want, got)
			})
		}
	})
	require.True(t, resolved)

	t.Run("staleness", func(t *testing.T) {
		now := time.Now()

		require.ElementsMatch(t, []netip.Addr{ip4, ip6}, r.findCached(fqdn, now))
		require.Empty(t, r.findCached(fqdn, now.Add(smallTTL+time.Second)))
	})
}
070701000000A6000081A4000000000000000000000001679A649F00000C4C000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/upstream/resolver_test.gopackage upstream_test
import (
"context"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewUpstreamResolver(t *testing.T) {
	const ttl = 60

	answerIP := netip.MustParseAddr("1.2.3.4")

	ups := &dnsproxytest.FakeUpstream{
		OnAddress: func() (_ string) { panic("not implemented") },
		OnClose:   func() (_ error) { panic("not implemented") },
		OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			a := &dns.A{
				Hdr: dns.RR_Header{
					Name:   req.Question[0].Name,
					Rrtype: dns.TypeA,
					Class:  dns.ClassINET,
					Ttl:    ttl,
				},
				A: answerIP.AsSlice(),
			}

			resp = (&dns.Msg{}).SetReply(req)
			resp.Answer = []dns.RR{a}

			return resp, nil
		},
	}

	r := &upstream.UpstreamResolver{Upstream: ups}

	got, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com")
	require.NoError(t, err)

	assert.NotEmpty(t, got)
}
func TestNewUpstreamResolver_validity(t *testing.T) {
	t.Parallel()

	withTimeoutOpt := &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: 3 * time.Second,
	}

	testCases := []struct {
		name       string
		addr       string
		wantErrMsg string
	}{{
		name:       "udp",
		addr:       "1.1.1.1:53",
		wantErrMsg: "",
	}, {
		name:       "dot",
		addr:       "tls://1.1.1.1",
		wantErrMsg: "",
	}, {
		name:       "doh",
		addr:       "https://1.1.1.1/dns-query",
		wantErrMsg: "",
	}, {
		name:       "sdns",
		addr:       "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
		wantErrMsg: "",
	}, {
		name:       "tcp",
		addr:       "tcp://9.9.9.9",
		wantErrMsg: "",
	}, {
		name: "invalid_tls",
		addr: "tls://dns.adguard.com",
		wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_https",
		addr: "https://dns.adguard.com/dns-query",
		wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_tcp",
		addr: "tcp://dns.adguard.com",
		wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_no_scheme",
		addr: "dns.adguard.com",
		wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			r, err := upstream.NewUpstreamResolver(tc.addr, withTimeoutOpt)
			if tc.wantErrMsg != "" {
				// Fail cleanly instead of panicking on err.Error() if no error
				// is returned.
				require.Error(t, err)
				assert.Equal(t, tc.wantErrMsg, err.Error())

				// NotBootstrapError is returned by value, so the errors.As
				// target must point to a value.  A **NotBootstrapError target
				// never matches, which silently skipped the check below.
				var nberr upstream.NotBootstrapError
				if errors.As(err, &nberr) {
					assert.NotNil(t, r)
				}

				return
			}

			require.NoError(t, err)

			addrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com")
			require.NoError(t, err)

			assert.NotEmpty(t, addrs)
		})
	}
}
070701000000A7000081A4000000000000000000000001679A649F00003215000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/upstream.go// Package upstream implements DNS clients for all known DNS encryption
// protocols.
package upstream
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/ameshkov/dnscrypt/v2"
"github.com/ameshkov/dnsstamps"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
)
// Upstream is a client of a single upstream DNS resolver.  All its methods must
// be safe for concurrent use.
type Upstream interface {
	// Closer releases the resources of the upstream.  Exchange shouldn't be
	// called after Close.
	io.Closer

	// Address returns the human-readable address of the upstream DNS resolver.
	// It may differ from what was passed to [AddressToUpstream].
	Address() (addr string)

	// Exchange sends req to this upstream and returns either the received
	// response or an error if something went wrong.  Implementations must not
	// modify req, and callers must not modify it until the method returns.  It
	// shouldn't be called after closing.
	Exchange(req *dns.Msg) (resp *dns.Msg, err error)
}
// QUICTraceFunc is a function that returns a [logging.ConnectionTracer] for the
// QUIC connection with the given role and connection ID.  It's the type of
// [Options.QUICTracer].
type QUICTraceFunc func(
ctx context.Context,
role logging.Perspective,
connID quic.ConnectionID,
) (tracer *logging.ConnectionTracer)
// Options configures the upstreams created by [AddressToUpstream].  A nil
// *Options is also accepted by [AddressToUpstream] and treated as an empty one.
type Options struct {
// Logger is used for logging during parsing and upstream exchange.  If nil,
// [slog.Default] is used.
Logger *slog.Logger
// VerifyServerCertificate, if set, is used as the VerifyPeerCertificate
// property of the *tls.Config for DNS-over-HTTPS, DNS-over-QUIC, and
// DNS-over-TLS.
VerifyServerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
// VerifyConnection, if set, is used as the VerifyConnection property of the
// *tls.Config for DNS-over-HTTPS, DNS-over-QUIC, and DNS-over-TLS.
VerifyConnection func(state tls.ConnectionState) error
// VerifyDNSCryptCertificate is the callback that receives the DNSCrypt server
// certificate.  It's called in dnsCrypt.exchangeDNSCrypt, and any error it
// returns is returned by the Upstream.Exchange method.
VerifyDNSCryptCertificate func(cert *dnscrypt.Cert) error
// QUICTracer is an optional callback that allows tracing every QUIC
// connection and logging every packet that goes through.
QUICTracer QUICTraceFunc
// RootCAs is the CertPool that must be used by all upstreams.  Redefining
// RootCAs makes sense on iOS to overcome the 15MB memory limit of the
// NEPacketTunnelProvider.
RootCAs *x509.CertPool
// CipherSuites is a custom list of TLSv1.2 ciphers.
CipherSuites []uint16
// Bootstrap is used to resolve the hostnames of upstreams.  If nil,
// [net.DefaultResolver] is used.
Bootstrap Resolver
// HTTPVersions is a list of HTTP versions that should be supported by the
// DNS-over-HTTPS client.  If not set, HTTP/1.1 and HTTP/2 will be used.
HTTPVersions []HTTPVersion
// Timeout is the default upstream timeout.  It's also used as a timeout for
// bootstrap DNS requests.  Zero value disables the timeout.
Timeout time.Duration
// InsecureSkipVerify disables verifying the server's certificate.
InsecureSkipVerify bool
// PreferIPv6 tells the bootstrapper to prefer IPv6 addresses for an
// upstream.
PreferIPv6 bool
}
// Clone returns a shallow copy of o.  Fields of reference types, such as
// slices, pointers, interfaces, and functions, are shared between o and the
// returned clone.
func (o *Options) Clone() (clone *Options) {
	c := *o

	return &c
}
// HTTPVersion is an HTTP version supported by the DNS-over-HTTPS client.  Its
// values double as ALPN protocol identifiers.
type HTTPVersion string

// Supported HTTP versions.
const (
	// HTTPVersion11 is HTTP/1.1.
	HTTPVersion11 HTTPVersion = "http/1.1"

	// HTTPVersion2 is HTTP/2.
	HTTPVersion2 HTTPVersion = "h2"

	// HTTPVersion3 is HTTP/3.
	HTTPVersion3 HTTPVersion = "h3"
)

// DefaultHTTPVersions are the HTTP versions used by the DNS-over-HTTPS client
// when none are configured: HTTP/1.1 and HTTP/2.
var DefaultHTTPVersions = []HTTPVersion{
	HTTPVersion11,
	HTTPVersion2,
}
// Default ports used when an upstream address has no port of its own.
const (
	// defaultPortPlain is the default port for plain DNS.
	defaultPortPlain = 53

	// defaultPortDoH is the default port for DNS-over-HTTPS.
	defaultPortDoH = 443

	// defaultPortDoT is the default port for DNS-over-TLS.
	defaultPortDoT = 853

	// defaultPortDoQ is the default port for DNS-over-QUIC.  Before draft
	// version -10, experiments were directed to use ports 8853 and 784.
	//
	// See https://www.rfc-editor.org/rfc/rfc9250.html#name-port-selection.
	defaultPortDoQ = 853
)
// AddressToUpstream converts addr to an Upstream using the specified options.
// addr can be either a URL, or a plain address, either a domain name or an IP.
//
//   - 1.2.3.4 or 1.2.3.4:4321 for plain DNS using IP address;
//   - udp://5.3.5.3:53 or 5.3.5.3:53 for plain DNS using IP address;
//   - udp://name.server:53 or name.server:53 for plain DNS using domain name;
//   - tcp://5.3.5.3:53 for plain DNS-over-TCP using IP address;
//   - tcp://name.server:53 for plain DNS-over-TCP using domain name;
//   - tls://5.3.5.3:853 for DNS-over-TLS using IP address;
//   - tls://name.server:853 for DNS-over-TLS using domain name;
//   - https://5.3.5.3:443/dns-query for DNS-over-HTTPS using IP address;
//   - https://name.server:443/dns-query for DNS-over-HTTPS using domain name;
//   - quic://5.3.5.3:853 for DNS-over-QUIC using IP address;
//   - quic://name.server:853 for DNS-over-QUIC using domain name;
//   - h3://dns.google for DNS-over-HTTPS that only works with HTTP/3;
//   - sdns://... for DNS stamp, see https://dnscrypt.info/stamps-specifications.
//
// If addr doesn't have port specified, the default port of the appropriate
// protocol will be used.
//
// opts may be nil.  AddressToUpstream works on a shallow copy of opts, so the
// caller's struct itself is never modified, for example, by setting the
// default logger.  Fields of reference types, such as slices, pointers, and
// functions, are still shared with the created upstream and shouldn't be
// modified afterwards.
func AddressToUpstream(addr string, opts *Options) (u Upstream, err error) {
	if opts == nil {
		opts = &Options{}
	} else {
		// Don't write the defaults into the caller's options, which may be
		// shared between several upstreams and used concurrently.
		opts = opts.Clone()
	}

	if opts.Logger == nil {
		opts.Logger = slog.Default()
	}

	var uu *url.URL
	if strings.Contains(addr, "://") {
		uu, err = url.Parse(addr)
		if err != nil {
			return nil, fmt.Errorf("failed to parse %s: %w", addr, err)
		}
	} else {
		uu = &url.URL{
			Scheme: "udp",
			Host:   addr,
		}
	}

	err = validateUpstreamURL(uu)
	if err != nil {
		// Don't wrap the error, because it's informative enough as is.
		return nil, err
	}

	return urlToUpstream(uu, opts)
}
// validateUpstreamURL returns an error if u isn't a valid upstream URL.  DNS
// stamps are accepted as is.  For other schemes, the port, if any, must be a
// decimal 16-bit number, and the host must be either an IP address, possibly
// enclosed in square brackets, or a valid domain name.
func validateUpstreamURL(u *url.URL) (err error) {
	if u.Scheme == "sdns" {
		return nil
	}

	host := u.Host

	// TODO(s.chzhen): Consider using [netutil.SplitHostPort].
	if h, port, splitErr := net.SplitHostPort(host); splitErr == nil {
		if _, err = strconv.ParseUint(port, 10, 16); err != nil {
			return fmt.Errorf("invalid port %s: %w", port, err)
		}

		host = h
	}

	// An IPv6 address without a port may still be enclosed in square brackets,
	// so strip them before checking whether host is an IP address.
	//
	// See https://github.com/AdguardTeam/dnsproxy/issues/379.
	possibleIP := host
	isEnclosed := len(host) >= len("[::]") &&
		strings.HasPrefix(host, "[") &&
		strings.HasSuffix(host, "]")
	if isEnclosed {
		possibleIP = host[1 : len(host)-1]
	}

	if netutil.IsValidIPString(possibleIP) {
		return nil
	}

	if err = netutil.ValidateDomainName(host); err != nil {
		return fmt.Errorf("invalid address %s: %w", host, err)
	}

	return nil
}
// urlToUpstream creates the Upstream matching the scheme of uu, configured with
// opts.  Unknown schemes result in an error.
func urlToUpstream(uu *url.URL, opts *Options) (u Upstream, err error) {
	scheme := uu.Scheme

	switch scheme {
	case "udp", "tcp":
		return newPlain(uu, opts)
	case "tls":
		return newDoT(uu, opts)
	case "https", "h3":
		return newDoH(uu, opts)
	case "quic":
		return newDoQ(uu, opts)
	case "sdns":
		return parseStamp(uu, opts)
	}

	return nil, fmt.Errorf("unsupported url scheme: %s", scheme)
}
// parseStamp converts the DNS stamp in upsURL to an Upstream.  If the stamp
// contains a server address, it must be an IP address, possibly with a port.
// That IP is then used as a static bootstrap for the created upstream only.
// The caller's opts are never modified.
func parseStamp(upsURL *url.URL, opts *Options) (u Upstream, err error) {
	stamp, err := dnsstamps.NewServerStampFromString(upsURL.String())
	if err != nil {
		return nil, fmt.Errorf("failed to parse %s: %w", upsURL, err)
	}

	// TODO(e.burkov): Port?
	if stamp.ServerAddrStr != "" {
		host, _, sErr := netutil.SplitHostPort(stamp.ServerAddrStr)
		if sErr != nil {
			host = stamp.ServerAddrStr
		}

		var ip netip.Addr
		ip, err = netip.ParseAddr(host)
		if err != nil {
			return nil, fmt.Errorf("invalid server stamp address %s", stamp.ServerAddrStr)
		}

		// Work on a copy.  opts may be shared with other upstreams, which must
		// not be bootstrapped with this stamp's address.
		opts = opts.Clone()
		opts.Bootstrap = StaticResolver{ip}
	}

	switch stamp.Proto {
	case dnsstamps.StampProtoTypePlain:
		return newPlain(&url.URL{Scheme: "udp", Host: stamp.ServerAddrStr}, opts)
	case dnsstamps.StampProtoTypeDNSCrypt:
		return newDNSCrypt(upsURL, opts), nil
	case dnsstamps.StampProtoTypeDoH:
		return newDoH(&url.URL{Scheme: "https", Host: stamp.ProviderName, Path: stamp.Path}, opts)
	case dnsstamps.StampProtoTypeDoQ:
		return newDoQ(&url.URL{Scheme: "quic", Host: stamp.ProviderName, Path: stamp.Path}, opts)
	case dnsstamps.StampProtoTypeTLS:
		return newDoT(&url.URL{Scheme: "tls", Host: stamp.ProviderName}, opts)
	default:
		return nil, fmt.Errorf("unsupported stamp protocol %s", &stamp.Proto)
	}
}
// addPort sets the port of u to port when [net.SplitHostPort] can't split the
// host of u, i.e. when it has no port yet.  A nil u is ignored.
func addPort(u *url.URL, port uint16) {
	if u == nil {
		return
	}

	if _, _, err := net.SplitHostPort(u.Host); err != nil {
		u.Host = netutil.JoinHostPort(u.Host, port)
	}
}
// logBegin writes a debug record about the request that is about to be sent to
// the upstream at addr over n.  It should be called right before dialing the
// connection.  The type and name of the first question are logged, or zero
// values if req has no questions.
func logBegin(l *slog.Logger, addr string, n network, req *dns.Msg) {
	var (
		qtype dns.Type
		qname string
	)

	if len(req.Question) > 0 {
		q := req.Question[0]
		qtype, qname = dns.Type(q.Qtype), q.Name
	}

	l.Debug("sending request", "addr", addr, "proto", n, "qtype", qtype, "qname", qname)
}
// logFinish writes a record about the end of the exchange with the upstream at
// addr over n.  It should be called right after receiving the response or
// after the failing action.  Successful exchanges and ordinary errors are
// logged at the debug level, and timeouts at the error level.
func logFinish(l *slog.Logger, addr string, n network, err error) {
	lvl, status := slog.LevelDebug, "ok"
	if err != nil {
		status = err.Error()

		// Make timeouts visible to the user.
		if isTimeout(err) {
			lvl = slog.LevelError
		}
	}

	l.Log(context.TODO(), lvl, "response received", "addr", addr, "proto", n, "status", status)
}
// isTimeout reports whether err signals a timeout.  That covers context
// cancellation, deadline errors, and any [net.Error] whose Timeout method
// returns true.
//
// TODO(e.burkov): Move to golibs.
func isTimeout(err error) (ok bool) {
	if errors.Is(err, context.Canceled) ||
		errors.Is(err, context.DeadlineExceeded) ||
		errors.Is(err, os.ErrDeadlineExceeded) {
		return true
	}

	var netErr net.Error
	if errors.As(err, &netErr) {
		return netErr.Timeout()
	}

	return false
}
// DialerInitializer is a function that creates a [bootstrap.DialHandler] and
// returns it along with any error encountered.
type DialerInitializer func() (handler bootstrap.DialHandler, err error)

// newDialerInitializer returns an initializer of the dialer for the addresses
// of u.  If the host of u is already an IP address with a port, the handler is
// created right away.  Otherwise, the host is resolved each time the
// initializer is called, using opts.Bootstrap or, if it's nil,
// [net.DefaultResolver].
func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer) {
	l := slog.Default()
	if opts.Logger != nil {
		l = opts.Logger.With(slogutil.KeyPrefix, "bootstrap")
	}

	// TODO(e.burkov): Add netutil.IsValidIPPortString.
	if _, err := netip.ParseAddrPort(u.Host); err == nil {
		// The host is already an IP address, so there is nothing to resolve.
		handler := bootstrap.NewDialContext(opts.Timeout, l, u.Host)

		return func() (h bootstrap.DialHandler, dialerErr error) {
			return handler, nil
		}
	}

	boot := opts.Bootstrap
	if boot == nil {
		boot = net.DefaultResolver
	}

	return func() (h bootstrap.DialHandler, err error) {
		return bootstrap.ResolveDialContext(u, opts.Timeout, boot, opts.PreferIPv6, l)
	}
}
070701000000A8000081A4000000000000000000000001679A649F00004DA1000000000000000000000000000000000000003300000000dnsproxy-0.75.0/upstream/upstream_internal_test.gopackage upstream
import (
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/netip"
"net/url"
"sync"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/ameshkov/dnsstamps"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(ameshkov): Make tests here not depend on external servers.
// TODO(d.kolyshev): Remove this after migrating dnscrypt to slog.
//
// TestMain runs the package tests through testutil.DiscardLogOutput, which,
// judging by its name, discards the log output produced during the tests.
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
func TestUpstream_bootstrapTimeout(t *testing.T) {
	t.Parallel()

	const (
		timeout = 100 * time.Millisecond
		count   = 10
	)

	// Test listener that never accepts connections to emulate faulty bootstrap.
	udpListener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, udpListener.Close)

	rslv, err := NewUpstreamResolver(udpListener.LocalAddr().String(), &Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: timeout,
	})
	require.NoError(t, err)

	// Create an upstream that uses this faulty bootstrap.
	u, err := AddressToUpstream("tls://random-domain-name", &Options{
		Logger:    slogutil.NewDiscardLogger(),
		Bootstrap: NewCachingResolver(rslv),
		Timeout:   timeout,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Each goroutine sends exactly one value, either to ch or to abort.  Both
	// channels are buffered for count values, so no goroutine blocks forever
	// once the test has failed and stopped receiving.
	ch := make(chan int, count)
	abort := make(chan string, count)

	for i := range count {
		go func(idx int) {
			t.Logf("Start %d", idx)

			req := createTestMessage()
			start := time.Now()
			_, rErr := u.Exchange(req)
			elapsed := time.Since(start)

			switch {
			case rErr == nil:
				// Must not happen since bootstrap server cannot work.
				abort <- fmt.Sprintf("the upstream must have timed out: %v", rErr)
			case elapsed > 3*timeout:
				// Check that the test didn't take too much time compared to
				// the configured timeout.  The actual elapsed time may be
				// higher than the timeout due to the execution environment, 3
				// is an arbitrarily chosen multiplier to account for that.
				abort <- fmt.Sprintf(
					"exchange took more time than the configured timeout: %s",
					elapsed,
				)
			default:
				t.Logf("Finished %d", idx)
				ch <- idx
			}
		}(i)
	}

	for range count {
		select {
		case res := <-ch:
			t.Logf("Got result from %d", res)
		case msg := <-abort:
			t.Fatalf("Aborted from the goroutine: %s", msg)
		case <-time.After(timeout * 10):
			t.Fatalf("No response in time")
		}
	}
}
func TestUpstreams(t *testing.T) {
	t.Parallel()

	const upsTimeout = 10 * time.Second

	l := slogutil.NewDiscardLogger()

	googleRslv, err := NewUpstreamResolver("8.8.8.8:53", &Options{
		Logger:  l,
		Timeout: upsTimeout,
	})
	require.NoError(t, err)

	cloudflareRslv, err := NewUpstreamResolver("1.0.0.1:53", &Options{
		Logger:  l,
		Timeout: upsTimeout,
	})
	require.NoError(t, err)

	googleBoot := NewCachingResolver(googleRslv)
	cloudflareBoot := NewCachingResolver(cloudflareRslv)

	// Each entry is expected to be unique; a duplicated
	// "tls://dns.adguard.com:853" case with googleBoot has been removed.
	upstreams := []struct {
		bootstrap Resolver
		address   string
	}{{
		bootstrap: googleBoot,
		address:   "8.8.8.8:53",
	}, {
		bootstrap: nil,
		address:   "1.1.1.1",
	}, {
		bootstrap: cloudflareBoot,
		address:   "1.1.1.1",
	}, {
		bootstrap: nil,
		address:   "tcp://1.1.1.1:53",
	}, {
		bootstrap: nil,
		address:   "94.140.14.14:5353",
	}, {
		bootstrap: nil,
		address:   "tls://1.1.1.1",
	}, {
		bootstrap: nil,
		address:   "tls://9.9.9.9:853",
	}, {
		bootstrap: googleBoot,
		address:   "tls://dns.adguard.com",
	}, {
		bootstrap: googleBoot,
		address:   "tls://dns.adguard.com:853",
	}, {
		bootstrap: nil,
		address:   "tls://one.one.one.one",
	}, {
		bootstrap: googleBoot,
		address:   "https://1dot1dot1dot1.cloudflare-dns.com/dns-query",
	}, {
		bootstrap: nil,
		address:   "https://dns.google/dns-query",
	}, {
		bootstrap: nil,
		address:   "https://doh.opendns.com/dns-query",
	}, {
		// AdGuard DNS (DNSCrypt)
		bootstrap: nil,
		address:   "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
	}, {
		// AdGuard Family (DNSCrypt)
		bootstrap: googleBoot,
		address:   "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNTo1NDQzILgxXdexS27jIKRw3C7Wsao5jMnlhvhdRUXWuMm1AFq6ITIuZG5zY3J5cHQuZmFtaWx5Lm5zMS5hZGd1YXJkLmNvbQ",
	}, {
		// Cloudflare DNS (DNS-over-HTTPS)
		bootstrap: googleBoot,
		address:   "sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk",
	}, {
		// Google (Plain)
		bootstrap: nil,
		address:   "sdns://AAcAAAAAAAAABzguOC44Ljg",
	}, {
		// AdGuard DNS (DNS-over-TLS)
		bootstrap: googleBoot,
		address:   "sdns://AwAAAAAAAAAAAAAPZG5zLmFkZ3VhcmQuY29t",
	}, {
		// AdGuard DNS (DNS-over-QUIC)
		bootstrap: googleBoot,
		address:   "sdns://BAcAAAAAAAAAAAAXZG5zLmFkZ3VhcmQtZG5zLmNvbTo3ODQ",
	}, {
		// Cloudflare DNS (DNS-over-HTTPS)
		bootstrap: nil,
		address:   "https://1.1.1.1/dns-query",
	}, {
		// AdGuard DNS (DNS-over-QUIC)
		bootstrap: googleBoot,
		address:   "quic://dns.adguard-dns.com",
	}, {
		// Google DNS (HTTP3)
		bootstrap: nil,
		address:   "h3://dns.google/dns-query",
	}}

	for _, test := range upstreams {
		t.Run(test.address, func(t *testing.T) {
			t.Parallel()

			u, upsErr := AddressToUpstream(
				test.address,
				&Options{Logger: l, Bootstrap: test.bootstrap, Timeout: upsTimeout},
			)
			require.NoErrorf(t, upsErr, "failed to generate upstream from address %s", test.address)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkUpstream(t, u, test.address)
		})
	}
}
func TestAddressToUpstream(t *testing.T) {
	cloudflareRslv, err := NewUpstreamResolver("1.1.1.1", nil)
	require.NoError(t, err)

	// withBoot is used by the cases that need to resolve a domain name.
	withBoot := &Options{
		Logger:    slogutil.NewDiscardLogger(),
		Bootstrap: NewCachingResolver(cloudflareRslv),
	}

	testCases := []struct {
		opts *Options
		addr string
		want string
	}{{
		opts: nil,
		addr: "1.1.1.1",
		want: "1.1.1.1:53",
	}, {
		opts: nil,
		addr: "1.1.1.1:5353",
		want: "1.1.1.1:5353",
	}, {
		opts: nil,
		addr: "one:5353",
		want: "one:5353",
	}, {
		opts: nil,
		addr: "one.one.one.one",
		want: "one.one.one.one:53",
	}, {
		opts: nil,
		addr: "udp://one.one.one.one",
		want: "one.one.one.one:53",
	}, {
		opts: withBoot,
		addr: "tcp://one.one.one.one",
		want: "tcp://one.one.one.one:53",
	}, {
		opts: withBoot,
		addr: "tls://one.one.one.one",
		want: "tls://one.one.one.one:853",
	}, {
		opts: withBoot,
		addr: "https://one.one.one.one",
		want: "https://one.one.one.one:443",
	}, {
		opts: withBoot,
		addr: "h3://one.one.one.one",
		want: "https://one.one.one.one:443",
	}, {
		opts: nil,
		addr: "::ffff:1.1.1.1",
		want: "[::ffff:1.1.1.1]:53",
	}, {
		opts: nil,
		addr: "https://[2606:4700:4700::1111]/dns-query",
		want: "https://[2606:4700:4700::1111]:443/dns-query",
	}, {
		opts: nil,
		addr: "https://[2606:4700:4700::1111]:443/dns-query",
		want: "https://[2606:4700:4700::1111]:443/dns-query",
	}}

	for _, tc := range testCases {
		t.Run(tc.addr, func(t *testing.T) {
			u, upsErr := AddressToUpstream(tc.addr, tc.opts)
			require.NoError(t, upsErr)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			assert.Equal(t, tc.want, u.Address())
		})
	}
}
func TestAddressToUpstream_bads(t *testing.T) {
	const (
		portOutOfRange = `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
			`value out of range`
		emptyPort   = `invalid port : strconv.ParseUint: parsing "": invalid syntax`
		emptyDomain = `invalid address : bad domain name "": domain name is empty`
		badRune     = `invalid address !!!: bad domain name "!!!": bad top-level domain name ` +
			`label "!!!": bad top-level domain name label rune '!'`
		allNumeric = `invalid address 123: bad domain name "123": bad top-level domain name ` +
			`label "123": all octets are numeric`
	)

	testCases := []struct {
		addr    string
		wantMsg string
	}{
		{addr: "asdf://1.1.1.1", wantMsg: "unsupported url scheme: asdf"},
		{addr: "12345.1.1.1:1234567", wantMsg: portOutOfRange},
		{addr: ":1234567", wantMsg: portOutOfRange},
		{addr: "host:", wantMsg: emptyPort},
		{addr: ":53", wantMsg: emptyDomain},
		{addr: "!!!", wantMsg: badRune},
		{addr: "123", wantMsg: allNumeric},
		{addr: "tcp://12345.1.1.1:1234567", wantMsg: portOutOfRange},
		{addr: "tcp://:1234567", wantMsg: portOutOfRange},
		{addr: "tcp://host:", wantMsg: emptyPort},
		{addr: "tcp://:53", wantMsg: emptyDomain},
		{addr: "tcp://!!!", wantMsg: badRune},
		{addr: "tcp://123", wantMsg: allNumeric},
	}

	for _, tc := range testCases {
		t.Run(tc.addr, func(t *testing.T) {
			_, err := AddressToUpstream(tc.addr, nil)
			testutil.AssertErrorMsg(t, tc.wantMsg, err)
		})
	}
}
func TestUpstreamDoTBootstrap(t *testing.T) {
t.Parallel()
upstreams := []struct {
address string
bootstrap string
}{{
address: "tls://one.one.one.one/",
bootstrap: "tls://1.1.1.1",
}, {
address: "tls://one.one.one.one/",
bootstrap: "https://1.1.1.1/dns-query",
}, {
address: "tls://one.one.one.one/",
// Cisco OpenDNS
bootstrap: "sdns://AQAAAAAAAAAADjIwOC42Ny4yMjAuMjIwILc1EUAgbyJdPivYItf9aR6hwzzI1maNDL4Ev6vKQ_t5GzIuZG5zY3J5cHQtY2VydC5vcGVuZG5zLmNvbQ",
}}
for _, tc := range upstreams {
t.Run(tc.address, func(t *testing.T) {
rslv, err := NewUpstreamResolver(tc.bootstrap, &Options{
Logger: slogutil.NewDiscardLogger(),
Timeout: timeout,
})
require.NoError(t, err)
u, err := AddressToUpstream(tc.address, &Options{
Logger: slogutil.NewDiscardLogger(),
Bootstrap: NewCachingResolver(rslv),
Timeout: timeout,
})
require.NoErrorf(t, err, "failed to generate upstream from address %s", tc.address)
testutil.CleanupAndRequireSuccess(t, u.Close)
checkUpstream(t, u, tc.address)
})
}
}
// Test for DoH and DoT upstreams with two bootstraps (only one is valid)
func TestUpstreamsInvalidBootstrap(t *testing.T) {
t.Parallel()
upstreams := []struct {
address string
bootstrap []string
}{{
address: "tls://dns.adguard.com",
bootstrap: []string{"1.1.1.1:555", "8.8.8.8:53"},
}, {
address: "tls://dns.adguard.com:853",
bootstrap: []string{"1.0.0.1", "8.8.8.8:535"},
}, {
address: "https://1dot1dot1dot1.cloudflare-dns.com/dns-query",
bootstrap: []string{"8.8.8.1", "1.0.0.1"},
}, {
address: "https://doh.opendns.com:443/dns-query",
bootstrap: []string{"1.2.3.4:79", "8.8.8.8:53"},
}, {
// Cloudflare DNS (DoH)
address: "sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk",
bootstrap: []string{"8.8.8.8:53", "8.8.8.1:53"},
}, {
// AdGuard DNS (DNS-over-TLS)
address: "sdns://AwAAAAAAAAAAAAAPZG5zLmFkZ3VhcmQuY29t",
bootstrap: []string{"1.2.3.4:55", "8.8.8.8"},
}}
l := slogutil.NewDiscardLogger()
for _, tc := range upstreams {
t.Run(tc.address, func(t *testing.T) {
t.Parallel()
var rslv ConsequentResolver
for _, b := range tc.bootstrap {
r, err := NewUpstreamResolver(b, &Options{
Logger: l,
Timeout: timeout,
})
require.NoError(t, err)
rslv = append(rslv, NewCachingResolver(r))
}
u, err := AddressToUpstream(tc.address, &Options{
Logger: l,
Bootstrap: rslv,
Timeout: timeout,
})
require.NoErrorf(t, err, "failed to generate upstream from address %s", tc.address)
testutil.CleanupAndRequireSuccess(t, u.Close)
checkUpstream(t, u, tc.address)
})
}
t.Run("bad_bootstrap", func(t *testing.T) {
_, err := NewUpstreamResolver("asdfasdf", nil)
assert.Error(t, err) // bad bootstrap "asdfasdf"
})
}
// TestAddressToUpstream_StaticResolver checks that DoT and DoH upstreams given
// by hostname can be reached through a StaticResolver bootstrap. It also checks
// that upstreams given as DNS stamps work even with an unusable bootstrap
// resolver.
func TestAddressToUpstream_StaticResolver(t *testing.T) {
	t.Parallel()

	handler := func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	}

	dotSrv := startDoTServer(t, handler)
	dohSrv := startDoHServer(t, testDoHServerOptions{})

	_, dohPort, err := net.SplitHostPort(dohSrv.addr)
	require.NoError(t, err)

	// brokenResolver has no upstream; the stamp cases below pass it as the
	// bootstrap.
	brokenResolver := &UpstreamResolver{Upstream: nil}

	dotAddrPort := netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(dotSrv.port)).String()

	dotStamp := (&dnsstamps.ServerStamp{
		Proto:         dnsstamps.StampProtoTypeTLS,
		ServerAddrStr: dotAddrPort,
		ProviderName:  dotAddrPort,
	}).String()

	dohStamp := (&dnsstamps.ServerStamp{
		Proto:         dnsstamps.StampProtoTypeDoH,
		ServerAddrStr: dohSrv.addr,
		ProviderName:  dohSrv.addr,
		Path:          "/dns-query",
	}).String()

	upstreams := []struct {
		rslv    Resolver
		name    string
		address string
	}{{
		rslv:    StaticResolver{netutil.IPv4Localhost()},
		name:    "dot",
		address: fmt.Sprintf("tls://some.dns.server:%d", dotSrv.port),
	}, {
		rslv:    StaticResolver{netutil.IPv4Localhost()},
		name:    "doh",
		address: fmt.Sprintf("https://some.dns.server:%s/dns-query", dohPort),
	}, {
		rslv:    brokenResolver,
		name:    "dot_stamp",
		address: dotStamp,
	}, {
		rslv:    brokenResolver,
		name:    "doh_stamp",
		address: dohStamp,
	}}

	for _, tc := range upstreams {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			u, uErr := AddressToUpstream(tc.address, &Options{
				Logger:             slogutil.NewDiscardLogger(),
				Bootstrap:          tc.rslv,
				Timeout:            timeout,
				InsecureSkipVerify: true,
			})
			require.NoError(t, uErr)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			assert.NotPanics(t, func() {
				checkUpstream(t, u, tc.address)
			})
		})
	}
}
// TestAddPort checks that addPort appends the given port to a URL host and
// leaves a port that is already present unchanged, including bracketed IPv6.
func TestAddPort(t *testing.T) {
	testCases := []struct {
		name string
		want string
		host string
		port uint16
	}{{
		name: "empty",
		want: ":0",
		host: "",
		port: 0,
	}, {
		name: "hostname",
		want: "example.org:53",
		host: "example.org",
		port: 53,
	}, {
		name: "ipv4",
		want: "1.2.3.4:1",
		host: "1.2.3.4",
		port: 1,
	}, {
		name: "ipv6",
		want: "[::1]:1",
		host: "::1",
		port: 1,
	}, {
		name: "ipv6_with_brackets",
		want: "[::1]:1",
		host: "[::1]",
		port: 1,
	}, {
		name: "hostname_with_port",
		want: "example.org:54",
		host: "example.org:54",
		port: 53,
	}, {
		name: "ipv4_with_port",
		want: "1.2.3.4:2",
		host: "1.2.3.4:2",
		port: 1,
	}, {
		name: "ipv6_with_brackets_and_port",
		want: "[::1]:2",
		host: "[::1]:2",
		port: 1,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Build a fresh URL per subtest, since addPort modifies it in place.
			target := &url.URL{Host: tc.host}
			addPort(target, tc.port)

			assert.Equal(t, tc.want, target.Host)
		})
	}
}
// checkUpstream exchanges a message built by createTestMessage with u and
// requires a reply that passes requireResponse.  addr appears only in the
// failure message.
func checkUpstream(t *testing.T, u Upstream, addr string) {
	t.Helper()

	query := createTestMessage()

	resp, err := u.Exchange(query)
	require.NoErrorf(t, err, "couldn't talk to upstream %s", addr)

	requireResponse(t, query, resp)
}
// checkRaceCondition runs several goroutines in parallel and each of them calls
// checkUpstream several times.
func checkRaceCondition(u Upstream) {
wg := sync.WaitGroup{}
// The number of requests to run in every goroutine.
reqCount := 10
// The overall number of goroutines to run.
goroutinesCount := 3
makeRequests := func() {
defer wg.Done()
for range reqCount {
req := createTestMessage()
// Ignore exchange errors here, the point is to check for races.
_, _ = u.Exchange(req)
}
}
wg.Add(goroutinesCount)
for range goroutinesCount {
go makeRequests()
}
wg.Wait()
}
// createTestMessage returns the A query used throughout these tests.
// respondToTestMessage answers it, and requireResponse validates the replies.
func createTestMessage() (m *dns.Msg) {
	const testHost = "google-public-dns-a.google.com"

	m = createHostTestMessage(testHost)

	return m
}
// respondToTestMessage builds a reply to m that carries a single A record
// with the address 8.8.8.8 for google-public-dns-a.google.com., TTL 100.
func respondToTestMessage(m *dns.Msg) (resp *dns.Msg) {
	answer := &dns.A{
		Hdr: dns.RR_Header{
			Name:   "google-public-dns-a.google.com.",
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
			Ttl:    100,
		},
		A: net.IPv4(8, 8, 8, 8),
	}

	resp = new(dns.Msg)
	resp.SetReply(m)
	resp.Answer = append(resp.Answer, answer)

	return resp
}
// createHostTestMessage returns a recursion-desired query for the A record of
// host.  The name is made fully qualified and a fresh message ID is assigned.
func createHostTestMessage(host string) (req *dns.Msg) {
	question := dns.Question{
		Name:   dns.Fqdn(host),
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	req = &dns.Msg{}
	req.Id = dns.Id()
	req.RecursionDesired = true
	req.Question = []dns.Question{question}

	return req
}
// requireResponse requires reply to answer req with the shape that
// respondToTestMessage produces: the same ID and exactly one A record equal to
// 8.8.8.8.
func requireResponse(t require.TestingT, req, reply *dns.Msg) {
	require.NotNil(t, reply)

	answers := reply.Answer
	require.Lenf(t, answers, 1, "wrong number of answers: %d", len(answers))
	require.Equal(t, req.Id, reply.Id)

	first := answers[0]
	rec, isA := first.(*dns.A)
	require.Truef(t, isA, "wrong answer type: %v", first)

	// net.IPv4 returns the 16-byte form, so normalize the answer to match.
	require.Equalf(t, net.IPv4(8, 8, 8, 8), rec.A.To16(), "wrong answer: %v", rec.A)
}
// createServerTLSConfig generates a self-signed CA certificate for
// tlsServerName, valid for five years from now.  The name is placed in the IP
// SANs if it parses as an IP address, and in the DNS SANs otherwise.
//
// It returns tlsConfig, which serves that certificate and also trusts it
// through RootCAs, so it can be used for both the test server and the client.
// It also returns rootCAs, the certificate pool containing only that
// certificate.
//
// TODO(ameshkov): start using rootCAs in tests instead of InsecureVerify.
func createServerTLSConfig(
	tb testing.TB,
	tlsServerName string,
) (tlsConfig *tls.Config, rootCAs *x509.CertPool) {
	tb.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(tb, err)

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	require.NoError(tb, err)

	notBefore := time.Now()
	notAfter := notBefore.Add(5 * 365 * time.Hour * 24)

	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"AdGuard Tests"},
		},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	ipAddress := net.ParseIP(tlsServerName)
	if ipAddress != nil {
		template.IPAddresses = append(template.IPAddresses, ipAddress)
	} else {
		template.DNSNames = append(template.DNSNames, tlsServerName)
	}

	// The template is its own parent, which makes the certificate self-signed.
	derBytes, err := x509.CreateCertificate(
		rand.Reader,
		&template,
		&template,
		publicKey(privateKey),
		privateKey,
	)
	require.NoError(tb, err)

	certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPem := pem.EncodeToMemory(
		&pem.Block{
			Type:  "RSA PRIVATE KEY",
			Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
		},
	)

	cert, err := tls.X509KeyPair(certPem, keyPem)
	require.NoError(tb, err)

	rootCAs = x509.NewCertPool()
	// AppendCertsFromPEM reports failure only through its boolean result, so
	// check it explicitly instead of handing out an empty pool.
	require.True(tb, rootCAs.AppendCertsFromPEM(certPem), "adding test certificate to root pool")

	tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{cert},
		ServerName:   tlsServerName,
		RootCAs:      rootCAs,
		MinVersion:   tls.VersionTLS12,
	}

	return tlsConfig, rootCAs
}
// publicKey extracts the public key from the specified private key.
func publicKey(priv any) (pub any) {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
default:
return nil
}
}
07070100000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000B00000000TRAILER!!!1347 blocks