File dnsproxy-0.75.0.obscpio of Package dnsproxy

07070100000000000081A4000000000000000000000001679A649F00000080000000000000000000000000000000000000001D00000000dnsproxy-0.75.0/.codecov.ymlcoverage:
  status:
    project:
      default:
        target: 40%
        threshold: null
    patch: false
    changes: false
07070100000001000081A4000000000000000000000001679A649F00000049000000000000000000000000000000000000001E00000000dnsproxy-0.75.0/.dockerignore# Ignore everything except for explicitly allowed stuff.
*
!build/docker
07070100000002000081A4000000000000000000000001679A649F00000011000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/.gitattributesvendor/** binary
07070100000003000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001800000000dnsproxy-0.75.0/.github07070100000004000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002200000000dnsproxy-0.75.0/.github/workflows07070100000005000081A4000000000000000000000001679A649F00000A62000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/.github/workflows/build.yamlname: Build

'env':
  'GO_VERSION': '1.23.5'

'on':
  'push':
    'tags':
      - 'v*'
    'branches':
      - '*'
  'pull_request':

jobs:
  tests:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os:
          - windows-latest
          - macos-latest
          - ubuntu-latest
    steps:
      - uses: actions/checkout@master
      - uses: actions/setup-go@v2
        with:
          go-version: '${{ env.GO_VERSION }}'
      - name: Run tests
        env:
          CI: "1"
        run: |-
          make test
      - name: Upload coverage
        uses: codecov/codecov-action@v1
        if: "success() && matrix.os == 'ubuntu-latest'"
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          file: ./coverage.txt

  build:
    needs:
      - tests
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@master
      - uses: actions/setup-go@v2
        with:
          go-version: '${{ env.GO_VERSION }}'
      - name: Build release
        run: |-
          set -e -u -x

          RELEASE_VERSION="${GITHUB_REF##*/}"
          if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi
          echo "RELEASE_VERSION=\"${RELEASE_VERSION}\"" >> $GITHUB_ENV

          make VERBOSE=1 VERSION="${RELEASE_VERSION}" release

          ls -l build/dnsproxy-*
      - name: Create release
        if: startsWith(github.ref, 'refs/tags/v')
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ github.ref }}
          release_name: Release ${{ github.ref }}
          draft: false
          prerelease: false
      - name: Upload
        if: startsWith(github.ref, 'refs/tags/v')
        uses: xresloader/upload-to-github-release@v1.3.12
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          file: "build/dnsproxy-*.tar.gz;build/dnsproxy-*.zip"
          tags: true
          draft: false

  notify:
    needs:
      - build
    if:
      ${{ always() &&
        (
          github.event_name == 'push' ||
          github.event.pull_request.head.repo.full_name == github.repository
        )
      }}
    runs-on: ubuntu-latest
    steps:
      - name: Conclusion
        uses: technote-space/workflow-conclusion-action@v1
      - name: Send Slack notif
        uses: 8398a7/action-slack@v3
        with:
          status: ${{ env.WORKFLOW_CONCLUSION }}
          fields: workflow, repo, message, commit, author, eventName,ref
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
07070100000006000081A4000000000000000000000001679A649F0000090E000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/.github/workflows/docker.yml'name': Docker

'env':
  'GO_VERSION': '1.23.5'

'on':
  'push':
    'tags':
      - 'v*'
    # Builds from the master branch will be pushed with the `dev` tag.
    'branches':
      - 'master'

'jobs':
  'docker':
    'runs-on': 'ubuntu-latest'
    'steps':
      - 'name': 'Checkout'
        'uses': 'actions/checkout@v3'
        'with':
          'fetch-depth': 0
      - 'name': 'Set up Go'
        'uses': 'actions/setup-go@v3'
        'with':
          'go-version': '${{ env.GO_VERSION }}'
      - 'name': 'Set up Go modules cache'
        'uses': 'actions/cache@v2'
        'with':
          'path': '~/go/pkg/mod'
          'key': "${{ runner.os }}-go-${{ hashFiles('go.sum') }}"
          'restore-keys': '${{ runner.os }}-go-'
      - 'name': 'Set up QEMU'
        'uses': 'docker/setup-qemu-action@v1'
      - 'name': 'Set up Docker Buildx'
        'uses': 'docker/setup-buildx-action@v1'
      - 'name': 'Publish to Docker Hub'
        'env':
          'DOCKER_USER': ${{ secrets.DOCKER_USER }}
          'DOCKER_PASSWORD': ${{ secrets.DOCKER_PASSWORD }}
        'run': |-
          set -e -u -x

          RELEASE_VERSION="${GITHUB_REF##*/}"
          if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi
          echo "RELEASE_VERSION=\"${RELEASE_VERSION}\"" >> $GITHUB_ENV

          docker login \
            -u="${DOCKER_USER}" \
            -p="${DOCKER_PASSWORD}"

          make \
            VERSION="${RELEASE_VERSION}" \
            DOCKER_IMAGE_NAME="adguard/dnsproxy" \
            DOCKER_OUTPUT="type=image,name=adguard/dnsproxy,push=true" \
            VERBOSE="1" \
            docker

  'notify':
    'needs':
      - 'docker'
    'if':
      ${{ always() &&
      (
      github.event_name == 'push' ||
      github.event.pull_request.head.repo.full_name == github.repository
      )
      }}
    'runs-on': ubuntu-latest
    'steps':
      - 'name': Conclusion
        'uses': technote-space/workflow-conclusion-action@v1
      - 'name': Send Slack notif
        'uses': 8398a7/action-slack@v3
        'with':
          'status': ${{ env.WORKFLOW_CONCLUSION }}
          'fields': workflow, repo, message, commit, author, eventName,ref
        'env':
          'GITHUB_TOKEN': ${{ secrets.GITHUB_TOKEN }}
          'SLACK_WEBHOOK_URL': ${{ secrets.SLACK_WEBHOOK_URL }}
07070100000007000081A4000000000000000000000001679A649F00000593000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/.github/workflows/lint.yaml'name': 'lint'

'env':
  'GO_VERSION': '1.23.5'

'on':
  'push':
    'tags':
      - 'v*'
    'branches':
      - '*'
  'pull_request':

'jobs':
  'go-lint':
    'runs-on': 'ubuntu-latest'
    'steps':
      - 'uses': 'actions/checkout@v2'
      - 'name': 'Set up Go'
        'uses': 'actions/setup-go@v3'
        'with':
          'go-version': '${{ env.GO_VERSION }}'
      - 'name': 'run-lint'
        'run': >
          make go-deps go-tools go-lint

  'notify':
    'needs':
      - 'go-lint'
    # Secrets are not passed to workflows that are triggered by a pull request
    # from a fork.
    #
    # Use always() to signal to the runner that this job must run even if the
    # previous ones failed.
    'if':
      ${{
      always() &&
      github.repository_owner == 'AdguardTeam' &&
      (
      github.event_name == 'push' ||
      github.event.pull_request.head.repo.full_name == github.repository
      )
      }}
    'runs-on': 'ubuntu-latest'
    'steps':
      - 'name': 'Conclusion'
        'uses': 'technote-space/workflow-conclusion-action@v1'
      - 'name': 'Send Slack notif'
        'uses': '8398a7/action-slack@v3'
        'with':
          'status': '${{ env.WORKFLOW_CONCLUSION }}'
          'fields': 'workflow, repo, message, commit, author, eventName, ref'
        'env':
          'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}'
          'SLACK_WEBHOOK_URL': '${{ secrets.SLACK_WEBHOOK_URL }}'
07070100000008000081A4000000000000000000000001679A649F000001D3000000000000000000000000000000000000001B00000000dnsproxy-0.75.0/.gitignore# Please, DO NOT put your text editors' temporary files here.  The more are
# added, the harder it gets to maintain and manage projects' gitignores.  Put
# them into your global gitignore file instead.
#
# See https://stackoverflow.com/a/7335487/1892060.
#
# Only build, run, and test outputs here.  Sorted.  With negations at the
# bottom to make sure they take effect.
*.out
*.test
/bin/
build
dnsproxy
dnsproxy.exe
example.crt
example.key
coverage.txt
config.yaml
07070100000009000081A4000000000000000000000001679A649F00002C57000000000000000000000000000000000000001800000000dnsproxy-0.75.0/LICENSE
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2020 Adguard Software Ltd

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
0707010000000A000081A4000000000000000000000001679A649F00000BDC000000000000000000000000000000000000001900000000dnsproxy-0.75.0/Makefile# Keep the Makefile POSIX-compliant.  We currently allow hyphens in
# target names, but that may change in the future.
#
# See https://pubs.opengroup.org/onlinepubs/9799919799/utilities/make.html.
.POSIX:

# This comment is used to simplify checking local copies of the
# Makefile.  Bump this number every time a significant change is made to
# this Makefile.
#
# AdGuard-Project-Version: 9

# Don't name these macros "GO" etc., because GNU Make apparently makes
# them exported environment variables with the literal value of
# "${GO:-go}" and so on, which is not what we need.  Use a dot in the
# name to make sure that users don't have an environment variable with
# the same name.
#
# See https://unix.stackexchange.com/q/646255/105635.
GO.MACRO = $${GO:-go}
VERBOSE.MACRO = $${VERBOSE:-0}

BRANCH = $${BRANCH:-$$(git rev-parse --abbrev-ref HEAD)}
DIST_DIR = build
GOAMD64 = v1
GOPROXY = https://proxy.golang.org|direct
GOTOOLCHAIN = go1.23.5
GOTELEMETRY = off
OUT = dnsproxy
RACE = 0
REVISION = $${REVISION:-$$(git rev-parse --short HEAD)}
VERSION = 0

ENV = env\
	BRANCH="$(BRANCH)"\
	DIST_DIR='$(DIST_DIR)'\
	GO="$(GO.MACRO)"\
	GOAMD64='$(GOAMD64)'\
	GOPROXY='$(GOPROXY)'\
	GOTELEMETRY='$(GOTELEMETRY)'\
	GOTOOLCHAIN='$(GOTOOLCHAIN)'\
	OUT='$(OUT)'\
	PATH="$${PWD}/bin:$$("$(GO.MACRO)" env GOPATH)/bin:$${PATH}"\
	RACE='$(RACE)'\
	REVISION="$(REVISION)"\
	VERBOSE="$(VERBOSE.MACRO)"\
	VERSION="$(VERSION)"\

# Keep the line above blank.

ENV_MISC = env\
	PATH="$${PWD}/bin:$$("$(GO.MACRO)" env GOPATH)/bin:$${PATH}"\
	VERBOSE="$(VERBOSE.MACRO)"\

# Keep the line above blank.

# Keep this target first, so that a naked make invocation triggers a
# full build.
build: go-deps go-build

init: ; git config core.hooksPath ./scripts/hooks

test: go-test

go-build:     ; $(ENV)          "$(SHELL)" ./scripts/make/go-build.sh
go-deps:      ; $(ENV)          "$(SHELL)" ./scripts/make/go-deps.sh
go-env:       ; $(ENV)          "$(GO.MACRO)" env
go-lint:      ; $(ENV)          "$(SHELL)" ./scripts/make/go-lint.sh
go-test:      ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh
go-tools:     ; $(ENV)          "$(SHELL)" ./scripts/make/go-tools.sh
go-upd-tools: ; $(ENV)          "$(SHELL)" ./scripts/make/go-upd-tools.sh

go-check: go-tools go-lint go-test

# A quick check to make sure that all operating systems relevant to the
# development of the project can be typechecked and built successfully.
go-os-check:
	$(ENV) GOOS='darwin'  "$(GO.MACRO)" vet ./...
	$(ENV) GOOS='freebsd' "$(GO.MACRO)" vet ./...
	$(ENV) GOOS='openbsd' "$(GO.MACRO)" vet ./...
	$(ENV) GOOS='linux'   "$(GO.MACRO)" vet ./...
	$(ENV) GOOS='windows' "$(GO.MACRO)" vet ./...

txt-lint: ; $(ENV) "$(SHELL)" ./scripts/make/txt-lint.sh

md-lint:  ; $(ENV_MISC) "$(SHELL)" ./scripts/make/md-lint.sh
sh-lint:  ; $(ENV_MISC) "$(SHELL)" ./scripts/make/sh-lint.sh

clean:   ; $(ENV) $(GO.MACRO) clean && rm -f -r '$(DIST_DIR)'

release: clean
	$(ENV) "$(SHELL)" ./scripts/make/build-release.sh

docker: release
	$(ENV) "$(SHELL)" ./scripts/make/build-docker.sh
0707010000000B000081A4000000000000000000000001679A649F00004ABE000000000000000000000000000000000000001A00000000dnsproxy-0.75.0/README.md[![Code Coverage](https://img.shields.io/codecov/c/github/AdguardTeam/dnsproxy/master.svg)](https://codecov.io/github/AdguardTeam/dnsproxy?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/AdguardTeam/dnsproxy)](https://goreportcard.com/report/AdguardTeam/dnsproxy)
[![Go Doc](https://godoc.org/github.com/AdguardTeam/dnsproxy?status.svg)](https://godoc.org/github.com/AdguardTeam/dnsproxy)

# DNS Proxy <!-- omit in toc -->

A simple DNS proxy server that supports all existing DNS protocols including
`DNS-over-TLS`, `DNS-over-HTTPS`, `DNSCrypt`, and `DNS-over-QUIC`. Moreover,
it can work as a `DNS-over-HTTPS`, `DNS-over-TLS`, `DNS-over-QUIC`, or `DNSCrypt` server.

- [How to install](#how-to-install)
- [How to build](#how-to-build)
- [Usage](#usage)
- [Examples](#examples)
  - [Simple options](#simple-options)
  - [Encrypted upstreams](#encrypted-upstreams)
  - [Encrypted DNS server](#encrypted-dns-server)
  - [Additional features](#additional-features)
  - [DNS64 server](#dns64-server)
  - [Fastest addr + cache-min-ttl](#fastest-addr--cache-min-ttl)
  - [Specifying upstreams for domains](#specifying-upstreams-for-domains)
  - [EDNS Client Subnet](#edns-client-subnet)
  - [Bogus NXDomain](#bogus-nxdomain)

## How to install

There are several ways to install `dnsproxy`:

1. Grab the binary for your device/OS from the [Releases][releases] page.
2. Use the [official Docker image][docker].
3. Build it yourself (see the instructions below).

[releases]: https://github.com/AdguardTeam/dnsproxy/releases
[docker]: https://hub.docker.com/r/adguard/dnsproxy

## How to build

You will need Go v1.21 or later.

```shell
$ make build
```

## Usage

```
Usage:
  dnsproxy [OPTIONS]

Application Options:
      --config-path=               yaml configuration file. Minimal working configuration in config.yaml.dist. Options passed through command
                                   line will override the ones from this file.
  -o, --output=                    Path to the log file. If not set, write to stdout.
  -c, --tls-crt=                   Path to a file with the certificate chain
  -k, --tls-key=                   Path to a file with the private key
      --https-server-name=         Set the Server header for the responses from the HTTPS server. (default: dnsproxy)
      --https-userinfo=            If set, all DoH queries are required to have this basic authentication information.
  -g, --dnscrypt-config=           Path to a file with DNSCrypt configuration. You can generate one using https://github.com/ameshkov/dnscrypt
      --edns-addr=                 Send EDNS Client Address
      --upstream-mode=             Defines the upstreams logic mode, possible values: load_balance, parallel, fastest_addr (default:
                                   load_balance)
  -l, --listen=                    Listening addresses
  -p, --port=                      Listening ports. Zero value disables TCP and UDP listeners
  -s, --https-port=                Listening ports for DNS-over-HTTPS
  -t, --tls-port=                  Listening ports for DNS-over-TLS
  -q, --quic-port=                 Listening ports for DNS-over-QUIC
  -y, --dnscrypt-port=             Listening ports for DNSCrypt
  -u, --upstream=                  An upstream to be used (can be specified multiple times). You can also specify path to a file with the
                                   list of servers
  -b, --bootstrap=                 Bootstrap DNS for DoH and DoT, can be specified multiple times (default: use system-provided)
  -f, --fallback=                  Fallback resolvers to use when regular ones are unavailable, can be specified multiple times. You can also
                                   specify path to a file with the list of servers
      --private-rdns-upstream=     Private DNS upstreams to use for reverse DNS lookups of private addresses, can be specified multiple times
      --dns64-prefix=              Prefix used to handle DNS64. If not specified, dnsproxy uses the 'Well-Known Prefix' 64:ff9b::.  Can be
                                   specified multiple times
      --private-subnets=           Private subnets to use for reverse DNS lookups of private addresses
      --bogus-nxdomain=            Transform the responses containing at least a single IP that matches specified addresses and CIDRs into
                                   NXDOMAIN.  Can be specified multiple times.
      --hosts-files=               List of paths to the hosts files, can be specified multiple times
      --timeout=                   Timeout for outbound DNS queries to remote upstream servers in a human-readable form (default: 10s)
      --cache-min-ttl=             Minimum TTL value for DNS entries, in seconds. Capped at 3600. Artificially extending TTLs should only be
                                   done with careful consideration.
      --cache-max-ttl=             Maximum TTL value for DNS entries, in seconds.
      --cache-size=                Cache size (in bytes). Default: 64k
  -r, --ratelimit=                 Ratelimit (requests per second)
      --ratelimit-subnet-len-ipv4= Ratelimit subnet length for IPv4. (default: 24)
      --ratelimit-subnet-len-ipv6= Ratelimit subnet length for IPv6. (default: 56)
      --udp-buf-size=              Set the size of the UDP buffer in bytes. A value <= 0 will use the system default.
      --max-go-routines=           Set the maximum number of go routines. A zero value will not not set a maximum.
      --tls-min-version=           Minimum TLS version, for example 1.0
      --tls-max-version=           Maximum TLS version, for example 1.3
      --pprof                      If present, exposes pprof information on localhost:6060.
      --version                    Prints the program version
  -v, --verbose                    Verbose output (optional)
      --insecure                   Disable secure TLS certificate validation
      --ipv6-disabled              If specified, all AAAA requests will be replied with NoError RCode and empty answer
      --http3                      Enable HTTP/3 support
      --cache-optimistic           If specified, optimistic DNS cache is enabled
      --cache                      If specified, DNS cache is enabled
      --refuse-any                 If specified, refuse ANY requests
      --edns                       Use EDNS Client Subnet extension
      --dns64                      If specified, dnsproxy will act as a DNS64 server
      --use-private-rdns           If specified, use private upstreams for reverse DNS lookups of private addresses
      --hosts-file-enabled=        If specified, use hosts files for resolving (default: true)

Help Options:
  -h, --help                       Show this help message
```

## Examples

### Simple options

Runs a DNS proxy on `0.0.0.0:53` with a single upstream - Google DNS.
```shell
./dnsproxy -u 8.8.8.8:53
```

The same proxy with verbose logging enabled, writing the log to the file `log.txt`.
```shell
./dnsproxy -u 8.8.8.8:53 -v -o log.txt
```

Runs a DNS proxy on `127.0.0.1:5353` with multiple upstreams.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8:53 -u 1.1.1.1:53
```

Listen on multiple interfaces and ports:
```shell
./dnsproxy -l 127.0.0.1 -l 192.168.1.10 -p 5353 -p 5354 -u 1.1.1.1
```

The plain DNS upstream server may be specified in several ways:

 -  With a plain IP address:
    ```shell
    ./dnsproxy -l 127.0.0.1 -u 8.8.8.8:53
    ```

 -  With a hostname or plain IP address and the `udp://` scheme:
    ```shell
    ./dnsproxy -l 127.0.0.1 -u udp://dns.google -u udp://1.1.1.1
    ```

 -  With a hostname or plain IP address and the `tcp://` scheme to force using
    TCP:
    ```shell
    ./dnsproxy -l 127.0.0.1 -u tcp://dns.google -u tcp://1.1.1.1
    ```

### Encrypted upstreams

DNS-over-TLS upstream:
```shell
./dnsproxy -u tls://dns.adguard.com
```

DNS-over-HTTPS upstream with specified bootstrap DNS:
```shell
./dnsproxy -u https://dns.adguard.com/dns-query -b 1.1.1.1:53
```

DNS-over-QUIC upstream:
```shell
./dnsproxy -u quic://dns.adguard.com
```

DNS-over-HTTPS upstream with HTTP/3 support enabled (HTTP/3 is used if it turns out to be faster):
```shell
./dnsproxy -u https://dns.google/dns-query --http3
```

DNS-over-HTTPS upstream with forced HTTP/3 (no fallback to other protocols):
```shell
./dnsproxy -u h3://dns.google/dns-query
```

DNSCrypt upstream ([DNS Stamp](https://dnscrypt.info/stamps) of AdGuard DNS):
```shell
./dnsproxy -u sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20
```

DNS-over-HTTPS upstream ([DNS Stamp](https://dnscrypt.info/stamps) of Cloudflare DNS):
```shell
./dnsproxy -u sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk
```

DNS-over-TLS upstream with two fallback servers (to be used when the main upstream is not available):
```shell
./dnsproxy -u tls://dns.adguard.com -f 8.8.8.8:53 -f 1.1.1.1:53
```

### Encrypted DNS server

Runs a DNS-over-TLS proxy on `127.0.0.1:853`.
```shell
./dnsproxy -l 127.0.0.1 --tls-port=853 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```

Runs a DNS-over-HTTPS proxy on `127.0.0.1:443`.
```shell
./dnsproxy -l 127.0.0.1 --https-port=443 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```

Runs a DNS-over-HTTPS proxy on `127.0.0.1:443` with HTTP/3 support.
```shell
./dnsproxy -l 127.0.0.1 --https-port=443 --http3 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```

Runs a DNS-over-QUIC proxy on `127.0.0.1:853`.
```shell
./dnsproxy -l 127.0.0.1 --quic-port=853 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```

Runs a DNSCrypt proxy on `127.0.0.1:443`.

```shell
./dnsproxy -l 127.0.0.1 --dnscrypt-config=./dnscrypt-config.yaml --dnscrypt-port=443 --upstream=8.8.8.8:53 -p 0
```

> Please note that in order to run a DNSCrypt proxy, you need to obtain DNSCrypt configuration first. You can use https://github.com/ameshkov/dnscrypt command-line tool to do that with a command like this `./dnscrypt generate --provider-name=2.dnscrypt-cert.example.org --out=dnscrypt-config.yaml`

### Additional features

Runs a DNS proxy on `0.0.0.0:53` with the rate limit set to `10 rps`, the DNS cache enabled, and type=ANY requests refused.
```shell
./dnsproxy -u 8.8.8.8:53 -r 10 --cache --refuse-any
```

Runs a DNS proxy on `127.0.0.1:5353` with multiple upstreams and enables parallel queries to all configured upstream servers.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8:53 -u 1.1.1.1:53 -u tls://dns.adguard.com --upstream-mode parallel
```

Loads upstreams list from a file.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u ./upstreams.txt
```

### DNS64 server

`dnsproxy` is capable of working as a DNS64 server.

> **What is DNS64/NAT64**
> This is a mechanism of providing IPv6 access to IPv4. Using a NAT64 gateway
> with IPv4-IPv6 translation capability lets IPv6-only clients connect to
> IPv4-only services via synthetic IPv6 addresses starting with a prefix that
> routes them to the NAT64 gateway. DNS64 is a DNS service that returns AAAA
> records with these synthetic IPv6 addresses for IPv4-only destinations
> (with A but not AAAA records in the DNS). This lets IPv6-only clients use
> NAT64 gateways without any other configuration.

See also [RFC 6147](https://datatracker.ietf.org/doc/html/rfc6147).

Enables DNS64 with the default [Well-Known Prefix][wkp]:
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8 --use-private-rdns --private-rdns-upstream=127.0.0.1 --dns64
```

You can also specify any number of custom DNS64 prefixes:
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8 --use-private-rdns --private-rdns-upstream=127.0.0.1 --dns64 --dns64-prefix=64:ffff:: --dns64-prefix=32:ffff::
```

Note that only the first specified prefix will be used for synthesis.

PTR queries for addresses within the specified ranges or the
[Well-Known one][wkp] could only be answered with locally appropriate data, so
dnsproxy will route those to the local upstream servers.  Those should be
specified and enabled if DNS64 is enabled.

[wkp]: https://datatracker.ietf.org/doc/html/rfc6052#section-2.1

### Fastest addr + cache-min-ttl

This option is useful for users with a problematic network connection.  In
this mode, `dnsproxy` detects the fastest IP address among all that were
returned and returns only that one.

Additionally, for those with a problematic network connection, it makes sense to
override `cache-min-ttl`.  In this case, `dnsproxy` will make sure that DNS
responses are cached for at least the specified amount of time.

It makes sense to run it with multiple upstream servers only.

Runs a DNS proxy with two upstreams, min-TTL set to 10 minutes, and fastest
address detection enabled:
```shell
./dnsproxy -u 8.8.8.8 -u 1.1.1.1 --cache --cache-min-ttl=600 --upstream-mode=fastest_addr
```

This mode is intended for users who run `dnsproxy` with multiple upstreams.

### Specifying upstreams for domains

You can specify upstreams that will be used for a specific domain(s). We use the
dnsmasq-like syntax, decorating domains with brackets (see `--server`
[description][server-description]).

**Syntax:** `[/[domain1][/../domainN]/]upstreamString`

Where `upstreamString` is one or many upstreams separated by space (e.g.
`1.1.1.1` or `1.1.1.1 2.2.2.2`).

If one or more domains are specified, that upstream (`upstreamString`) is used
only for those domains. Usually, it is used for private nameservers. For
instance, if you have a nameserver on your network which deals with
`xxx.internal.local` at `192.168.0.1` then you can specify
`[/internal.local/]192.168.0.1`, and dnsproxy will send all queries to that
nameserver. Everything else will be sent to the default upstreams (which are
mandatory!).

1. An empty domain specification, `//`, has the special meaning of "unqualified
   names only", which will be used to resolve names with a single label in them,
   or with exactly two labels in case of `DS` requests.
2. More specific domains take precedence over less specific domains, so:
   `--upstream=[/host.com/]1.2.3.4 --upstream=[/www.host.com/]2.3.4.5` will send
   queries for `*.host.com` to `1.2.3.4`, except `*.www.host.com`, which will go
   to `2.3.4.5`.
3. The special server address `#` means, "use the common servers", so:
   `--upstream=[/host.com/]1.2.3.4 --upstream=[/www.host.com/]#` will send
   queries for `*.host.com` to `1.2.3.4`, except `*.www.host.com` which will be
   forwarded as usual.
4. The wildcard `*` has special meaning of "any sub-domain", so:
   `--upstream=[/*.host.com/]1.2.3.4` will send queries for `*.host.com` to
   `1.2.3.4`, but `host.com` will be forwarded to default upstreams.

**Examples**

Sends requests for `*.local` domains to `192.168.0.1:53`. Other requests are
sent to `8.8.8.8:53`:

```sh
./dnsproxy\
    -u "8.8.8.8:53"\
    -u "[/local/]192.168.0.1:53"
```

Sends requests for `*.host.com` to `1.1.1.1:53` except for `*.maps.host.com`
which are sent to `8.8.8.8:53` (along with other requests):

```sh
./dnsproxy\
    -u "8.8.8.8:53"\
    -u "[/host.com/]1.1.1.1:53"\
    -u "[/maps.host.com/]#"
```

Sends requests for `*.host.com` to `1.1.1.1:53` except for `host.com` which is
sent to `9.9.9.10:53`, and all other requests are sent to `8.8.8.8:53`:

```sh
./dnsproxy\
    -u "8.8.8.8:53"\
    -u "[/host.com/]9.9.9.10:53"\
    -u "[/*.host.com/]1.1.1.1:53"
```

Sends requests for `com` (and its subdomains) to `1.2.3.4:53`, requests for
other top-level domains to `1.1.1.1:53`, and all other requests to `8.8.8.8:53`:

```sh
./dnsproxy\
    -u "8.8.8.8:53"\
    -u "[//]1.1.1.1:53"\
    -u "[/com/]1.2.3.4:53"
```

### Specifying private rDNS upstreams

You can specify upstreams that will be used for reverse DNS requests of type PTR
for private addresses.  The same applies to the authority requests of types SOA
and NS.  The set of private addresses is defined by the `--private-subnets`
option, and the set from [RFC 6303][rfc6303] is used by default.

The additional requirement for the domains specified for upstreams is that they
must be `in-addr.arpa`, `ip6.arpa`, or their subdomains.  Addresses encoded in
the domains should also be private.

**Examples**

Sends queries for `*.168.192.in-addr.arpa` to `192.168.1.2`, if requested by a
client from the `192.168.0.0/16` subnet.  Other queries are answered with `NXDOMAIN`:

```sh
./dnsproxy\
    -l "0.0.0.0"\
    -u "8.8.8.8"\
    --use-private-rdns\
    --private-subnets="192.168.0.0/16"\
    --private-rdns-upstream="192.168.1.2"
```

Sends queries for `*.in-addr.arpa` to `192.168.1.2`, `*.ip6.arpa` to `fe80::1`,
if requested by a client within the default [RFC 6303][rfc6303] subnet set.
Other queries are answered with `NXDOMAIN`:

```sh
./dnsproxy\
    -l "0.0.0.0"\
    -u 8.8.8.8\
    --use-private-rdns\
    --private-rdns-upstream="192.168.1.2"\
    --private-rdns-upstream="[/ip6.arpa/]fe80::1"
```

[rfc6303]: https://datatracker.ietf.org/doc/html/rfc6303
[server-description]: http://www.thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html


### EDNS Client Subnet

To enable support for EDNS Client Subnet extension you should run dnsproxy with `--edns` flag:

```
./dnsproxy -u 8.8.8.8:53 --edns
```

Now, if you connect to the proxy from the Internet, it will pass your original IP address's prefix through to the upstream server.  This way the upstream server may respond with IP addresses of the servers that are located near you to minimize latency.

If you want to use the EDNS CS feature when you're connecting to the proxy from a local network, you need to set the `--edns-addr=PUBLIC_IP` argument:

```
./dnsproxy -u 8.8.8.8:53 --edns --edns-addr=72.72.72.72
```

Now even if your IP address is 192.168.0.1 and it's not a public IP, the proxy will pass through 72.72.72.72 to the upstream server.

### Bogus NXDomain

This option is similar to dnsmasq `bogus-nxdomain`.  `dnsproxy` will transform
responses that contain at least a single IP address which is also specified by
the option into `NXDOMAIN`. Can be specified multiple times.

In the example below, we use AdGuard DNS server that returns `0.0.0.0` for
blocked domains, and transform them to `NXDOMAIN`.

```
./dnsproxy -u 94.140.14.14:53 --bogus-nxdomain=0.0.0.0
```

CIDR ranges are supported as well.  The following will respond with `NXDOMAIN`
instead of responses containing any IP from `192.168.0.0`-`192.168.255.255`:

```
./dnsproxy -u 192.168.0.15:53 --bogus-nxdomain=192.168.0.0/16
```

### Basic Auth for DoH

By setting the `--https-userinfo` option you can use `dnsproxy` as a DoH proxy
with basic authentication requirements.

For example:

```sh
./dnsproxy\
    --https-port='443'\
    --https-userinfo='user:p4ssw0rd'\
    --tls-crt='…/my.crt'\
    --tls-key='…/my.key'\
    -u '94.140.14.14:53'
```

This configuration will only allow DoH queries that contain an `Authorization`
header containing the BasicAuth credentials for user `user` with password
`p4ssw0rd`.

Add `-p 0` if you also want to disable plain-DNS handling and make `dnsproxy`
only serve DoH with Basic Auth checking.
0707010000000C000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001D00000000dnsproxy-0.75.0/bamboo-specs0707010000000D000081A4000000000000000000000001679A649F000009AB000000000000000000000000000000000000002900000000dnsproxy-0.75.0/bamboo-specs/bamboo.yaml---
'version': 2
'plan':
    'project-key': 'GO'
    'key': 'DNSPROXY'
    'name': 'dnsproxy - Build and run tests'
'variables':
    'dockerFpm': 'alanfranz/fpm-within-docker:ubuntu-bionic'
    # When there is a patch release of Go available, set this property to an
    # exact patch version as opposed to a minor one to make sure that this exact
    # version is actually used and not whatever the docker daemon on the CI has
    # cached a few months ago.
    'dockerGo': 'golang:1.23.5'
    'maintainer': 'Adguard Go Team'
    'name': 'dnsproxy'

'stages':
# TODO(e.burkov):  Add separate lint stage for texts.
  - 'Lint':
      'manual': false
      'final': false
      'jobs':
        - 'Lint'
  - 'Test':
      'manual': false
      'final': false
      'jobs':
        - 'Test'

'Lint':
    'docker':
        'image': '${bamboo.dockerGo}'
        'volumes':
            '${system.GO_CACHE_DIR}': '${bamboo.cacheGo}'
            '${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}'
    'key': 'LINT'
    'other':
        'clean-working-dir': true
    'requirements':
      - 'adg-docker': true
    'tasks':
      - 'checkout':
             'force-clean-build': true
      - 'script':
              'interpreter': 'SHELL'
              'scripts':
                - |
                  set -e -f -u -x

                  make VERBOSE=1 GOMAXPROCS=1 go-tools go-lint

'Test':
    'docker':
        'image': '${bamboo.dockerGo}'
        'volumes':
            '${system.GO_CACHE_DIR}': '${bamboo.cacheGo}'
            '${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}'
    'key': 'TEST'
    'other':
        'clean-working-dir': true
    'requirements':
      - 'adg-docker': true
    'tasks':
      - 'checkout':
            'force-clean-build': true
      - 'script':
            'interpreter': 'SHELL'
            # Projects that have go-bench and/or go-fuzz targets should add them
            # here as well.
            'scripts':
              - |
                set -e -f -u -x

                make VERBOSE=1 go-deps go-test

'branches':
    'create': 'for-pull-request'
    'delete':
        'after-deleted-days': 1
        'after-inactive-days': 5
    'link-to-jira': true

'notifications':
  - 'events':
      - 'plan-status-changed'
    'recipients':
      - 'webhook':
            'name': 'Build webhook'
            'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo'

'labels': []

'other':
    'concurrent-build-plugin': 'system-default'
0707010000000E000081A4000000000000000000000001679A649F00000230000000000000000000000000000000000000002100000000dnsproxy-0.75.0/config.yaml.dist# This is the yaml configuration file for dnsproxy with minimal working
# configuration.  All the available options can be seen with ./dnsproxy --help.
# To use it with dnsproxy, specify the --config-path=/<path-to-config.yaml>
# option.  Any other command-line options specified will override the values
# from the config file.
---
bootstrap:
  - "8.8.8.8:53"
listen-addrs:
  - "0.0.0.0"
listen-ports:
  - 53
max-go-routines: 0
ratelimit: 0
ratelimit-subnet-len-ipv4: 24
ratelimit-subnet-len-ipv6: 64
udp-buf-size: 0
upstream:
  - "1.1.1.1:53"
timeout: '10s'
0707010000000F000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001700000000dnsproxy-0.75.0/docker07070100000010000081A4000000000000000000000001679A649F0000072A000000000000000000000000000000000000002200000000dnsproxy-0.75.0/docker/Dockerfile# A docker file for scripts/make/build-docker.sh.

FROM alpine:3.18

ARG BUILD_DATE
ARG VERSION
ARG VCS_REF

LABEL\
	maintainer="AdGuard Team <devteam@adguard.com>" \
	org.opencontainers.image.authors="AdGuard Team <devteam@adguard.com>" \
	org.opencontainers.image.created=$BUILD_DATE \
	org.opencontainers.image.description="Simple DNS proxy with DoH, DoT, DoQ and DNSCrypt support" \
	org.opencontainers.image.documentation="https://github.com/AdguardTeam/dnsproxy" \
	org.opencontainers.image.licenses="Apache-2.0" \
	org.opencontainers.image.revision=$VCS_REF \
	org.opencontainers.image.source="https://github.com/AdguardTeam/dnsproxy" \
	org.opencontainers.image.title="dnsproxy" \
	org.opencontainers.image.url="https://github.com/AdguardTeam/dnsproxy" \
	org.opencontainers.image.vendor="AdGuard" \
	org.opencontainers.image.version=$VERSION

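# Install CA certificates, libcap (for setcap below), and time zone data, and
# create the application directory owned by nobody.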
RUN apk --no-cache add ca-certificates libcap tzdata && \
	mkdir -p /opt/dnsproxy && chown -R nobody: /opt/dnsproxy

ARG DIST_DIR
ARG TARGETARCH
ARG TARGETOS
ARG TARGETVARIANT

COPY --chown=nobody:nogroup\
	./${DIST_DIR}/docker/dnsproxy_${TARGETOS}_${TARGETARCH}_${TARGETVARIANT}\
	/opt/dnsproxy/dnsproxy
COPY --chown=nobody:nogroup\
    ./${DIST_DIR}/docker/config.yaml\
    /opt/dnsproxy/config.yaml

RUN setcap 'cap_net_bind_service=+eip' /opt/dnsproxy/dnsproxy

# 53     : TCP, UDP : DNS
# 80     : TCP      : HTTP
# 443    : TCP, UDP : HTTPS, DNS-over-HTTPS (incl. HTTP/3), DNSCrypt (main)
# 853    : TCP, UDP : DNS-over-TLS, DNS-over-QUIC
# 5443   : TCP, UDP : DNSCrypt (alt)
# 6060   : TCP      : HTTP (pprof)
EXPOSE 53/tcp 53/udp \
       80/tcp \
       443/tcp 443/udp \
       853/tcp 853/udp \
       5443/tcp 5443/udp \
       6060/tcp

WORKDIR /opt/dnsproxy

ENTRYPOINT ["/opt/dnsproxy/dnsproxy"]
CMD ["--config-path=/opt/dnsproxy/config.yaml"]
07070100000011000081A4000000000000000000000001679A649F00000490000000000000000000000000000000000000002100000000dnsproxy-0.75.0/docker/README.md# DNS Proxy

A simple DNS proxy server that supports all existing DNS protocols including
`DNS-over-TLS`, `DNS-over-HTTPS`, `DNSCrypt`, and `DNS-over-QUIC`. Moreover,
it can work as a `DNS-over-HTTPS`, `DNS-over-TLS` or `DNS-over-QUIC` server.

Learn more about dnsproxy and its full capabilities in
its [GitHub repo][dnsproxy].

[dnsproxy]: https://github.com/AdguardTeam/dnsproxy

## Quick start

### Pull the Docker image

This command will pull the latest stable version:

```shell
docker pull adguard/dnsproxy
```

### Run the container

Run the container with the default configuration (see `config.yaml.dist` in the
repository) and expose DNS ports.

```shell
docker run --name dnsproxy \
  -p 53:53/tcp -p 53:53/udp \
  adguard/dnsproxy
```

Run the container with command-line args configuration and expose DNS ports.

```shell
docker run --name dnsproxy_google_dns \
  -p 53:53/tcp -p 53:53/udp \
  adguard/dnsproxy \
  -u 8.8.8.8:53
```

Run the container with a configuration file and expose DNS ports.

```shell
docker run --name dnsproxy_google_dns \
  -p 53:53/tcp -p 53:53/udp \
  -v $PWD/config.yaml:/opt/dnsproxy/config.yaml \
  adguard/dnsproxy
```
07070100000012000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001700000000dnsproxy-0.75.0/fastip07070100000013000081A4000000000000000000000001679A649F00000A54000000000000000000000000000000000000002000000000dnsproxy-0.75.0/fastip/cache.gopackage fastip

import (
	"encoding/binary"
	"net/netip"
	"time"
)

// fastestAddrCacheTTLSec is the time-to-live, in seconds, of the cached ping
// results for IP addresses, that is ten minutes.
const fastestAddrCacheTTLSec = 10 * 60

// cacheEntry represents an item that will be stored in the cache.  It's
// serialized with packCacheEntry and deserialized with unpackCacheEntry.
//
// TODO(e.burkov): Rewrite the cache using zero-values instead of storing
// useless boolean as an integer.
type cacheEntry struct {
	// status is 1 if the item is timed out and 0 if the ping succeeded.
	status      int
	// latencyMsec is the dial latency in milliseconds.  It's stored in two
	// bytes when packed, see packCacheEntry.
	latencyMsec uint
}

// packCacheEntry packs the cache entry and the TTL to bytes in the following
// order:
//
//   - expire   [4]byte  (Unix time, seconds),
//   - status   byte     (0 for ok, 1 for timed out),
//   - latency  [2]byte  (milliseconds).
//
// A latency that doesn't fit into two bytes is clamped to the maximum value
// instead of being truncated, so that a very slow address is never recorded as
// a fast one.
func packCacheEntry(ent *cacheEntry, ttl uint32) (d []byte) {
	// maxLatency is the largest latency representable in the packed form.
	const maxLatency = 1<<16 - 1

	expire := uint32(time.Now().Unix()) + ttl

	d = make([]byte, 4+1+2)
	binary.BigEndian.PutUint32(d[0:4], expire)
	d[4] = byte(ent.status)

	latency := ent.latencyMsec
	if latency > maxLatency {
		latency = maxLatency
	}

	binary.BigEndian.PutUint16(d[5:7], uint16(latency))

	return d
}

// unpackCacheEntry unpacks bytes in the format produced by packCacheEntry into
// a cache entry.  It returns nil if the record is expired or if data is too
// short to contain a whole record.
func unpackCacheEntry(data []byte) (ent *cacheEntry) {
	// Don't panic on malformed values: the underlying cache can be written
	// directly, bypassing packCacheEntry.
	if len(data) < 4+1+2 {
		return nil
	}

	expire := binary.BigEndian.Uint32(data[:4])
	if int64(expire) <= time.Now().Unix() {
		return nil
	}

	return &cacheEntry{
		status:      int(data[4]),
		latencyMsec: uint(binary.BigEndian.Uint16(data[5:7])),
	}
}

// cacheFind looks up the cached ping result for ip.  It returns nil when
// there is no record for ip or when the record has already expired.
func (f *FastestAddr) cacheFind(ip netip.Addr) (ent *cacheEntry) {
	if packed := f.ipCache.Get(ip.AsSlice()); packed != nil {
		return unpackCacheEntry(packed)
	}

	return nil
}

// cacheAddFailure records a failed ping of ip in the cache.  An existing,
// unexpired record for ip is never overwritten by a failure.
func (f *FastestAddr) cacheAddFailure(ip netip.Addr) {
	f.ipCacheLock.Lock()
	defer f.ipCacheLock.Unlock()

	if f.cacheFind(ip) != nil {
		return
	}

	f.cacheAdd(&cacheEntry{status: 1}, ip, fastestAddrCacheTTLSec)
}

// cacheAddSuccessful records a successful ping of ip with the given latency in
// milliseconds.  An existing record is replaced only when it's a failure or
// when its latency is higher than the new one.
func (f *FastestAddr) cacheAddSuccessful(ip netip.Addr, latency uint) {
	f.ipCacheLock.Lock()
	defer f.ipCacheLock.Unlock()

	cached := f.cacheFind(ip)
	keepCached := cached != nil && cached.status == 0 && cached.latencyMsec <= latency
	if keepCached {
		return
	}

	f.cacheAdd(&cacheEntry{latencyMsec: latency}, ip, fastestAddrCacheTTLSec)
}

// cacheAdd packs ent with the given TTL in seconds and stores it in the cache
// under the bytes of ip, replacing any previous value.
func (f *FastestAddr) cacheAdd(ent *cacheEntry, ip netip.Addr, ttl uint32) {
	f.ipCache.Set(ip.AsSlice(), packCacheEntry(ent, ttl))
}
07070100000014000081A4000000000000000000000001679A649F00000912000000000000000000000000000000000000002500000000dnsproxy-0.75.0/fastip/cache_test.gopackage fastip

import (
	"net"
	"net/netip"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/stretchr/testify/assert"
)

// TestCacheAdd checks that an entry added to the cache can be found again.
func TestCacheAdd(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})
	ip := netip.MustParseAddr("1.1.1.1")

	f.cacheAdd(&cacheEntry{status: 0, latencyMsec: 111}, ip, fastestAddrCacheTTLSec)

	assert.NotNil(t, f.cacheFind(ip))
}

// TestCacheTtl checks that a cached entry stops being returned once its TTL
// has passed.
func TestCacheTtl(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})
	ip := netip.MustParseAddr("1.1.1.1")

	const ttlSec = 1
	f.cacheAdd(&cacheEntry{status: 0, latencyMsec: 111}, ip, ttlSec)
	assert.NotNil(t, f.cacheFind(ip))

	// Sleep for slightly longer than the TTL so that the entry expires.
	time.Sleep(1001 * time.Millisecond)
	assert.Nil(t, f.cacheFind(ip))
}

// TestCacheAddSuccessfulOverwrite checks that a successful ping result replaces
// a cached failure.
func TestCacheAddSuccessfulOverwrite(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})

	ip := netip.MustParseAddr("1.1.1.1")
	f.cacheAddFailure(ip)

	// Check that the failure is cached.  Stop the test on a nil entry, since
	// the field accesses below would otherwise panic.
	ent := f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}

	assert.Equal(t, 1, ent.status)

	// Check that the successful result overwrites the failure.
	f.cacheAddSuccessful(ip, 11)

	ent = f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}

	assert.Equal(t, 0, ent.status)
	assert.Equal(t, uint(11), ent.latencyMsec)
}

// TestCacheAddFailureNoOverwrite checks that a failed ping doesn't replace a
// cached successful result.
func TestCacheAddFailureNoOverwrite(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})

	ip := netip.MustParseAddr("1.1.1.1")
	f.cacheAddSuccessful(ip, 11)

	// Check that the successful result is cached.  Stop the test on a nil
	// entry, since the field accesses below would otherwise panic.
	ent := f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}

	assert.Equal(t, 0, ent.status)

	// Check that a failure doesn't overwrite the existing record.
	f.cacheAddFailure(ip)

	ent = f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}

	assert.Equal(t, 0, ent.status)
	assert.Equal(t, uint(11), ent.latencyMsec)
}

// TestCache checks that cacheFind returns both an entry stored directly in the
// underlying cache under a 4-byte IPv4 key and an entry stored with cacheAdd.
func TestCache(t *testing.T) {
	f := New(&Config{Logger: slogutil.NewDiscardLogger()})
	ent := cacheEntry{
		status:      0,
		latencyMsec: 111,
	}

	// Use the long TTL so that the entry can't expire on a second boundary
	// between storing and checking it.
	val := packCacheEntry(&ent, fastestAddrCacheTTLSec)
	f.ipCache.Set(net.ParseIP("1.1.1.1").To4(), val)

	got := f.cacheFind(netip.MustParseAddr("1.1.1.1"))
	if assert.NotNil(t, got) {
		assert.Equal(t, 0, got.status)
		assert.Equal(t, uint(111), got.latencyMsec)
	}

	ent = cacheEntry{
		status:      0,
		latencyMsec: 222,
	}

	ip := netip.MustParseAddr("2.2.2.2")
	f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec)

	got = f.cacheFind(ip)
	if assert.NotNil(t, got) {
		assert.Equal(t, 0, got.status)
		assert.Equal(t, uint(222), got.latencyMsec)
	}
}
07070100000015000081A4000000000000000000000001679A649F00001618000000000000000000000000000000000000002200000000dnsproxy-0.75.0/fastip/fastest.go// Package fastip implements the algorithm that allows to query multiple
// resolvers, ping all IP addresses that were returned, and return the fastest
// one among them.
package fastip

import (
	"log/slog"
	"net"
	"net/netip"
	"strings"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/container"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

const (
	// LogPrefix is the prefix attached to the log messages of this package.
	LogPrefix = "fastip"

	// DefaultPingWaitTimeout is how long to wait, by default, for the ping
	// operations to finish.
	DefaultPingWaitTimeout = 1 * time.Second
)

// FastestAddr provides methods to determine the fastest network addresses.  Use
// [New] to create an instance.
type FastestAddr struct {
	// logger is used for logging during the process.  It is never nil.
	logger *slog.Logger

	// pinger is the dialer with predefined timeout for pinging TCP connections.
	pinger *net.Dialer

	// ipCacheLock protects ipCache.
	ipCacheLock *sync.Mutex

	// ipCache caches the packed ping results, status and latency, keyed by the
	// bytes of the IP address.
	ipCache cache.Cache

	// pingPorts are the ports to ping on.
	pingPorts []uint

	// pingWaitTimeout is the timeout for waiting all the resolved addresses to
	// be pinged.  Any ping results received after that moment are cached, but
	// won't be used.
	pingWaitTimeout time.Duration
}

// NewFastestAddr initializes a new instance of *FastestAddr with the default
// logger and the default ping wait timeout.
//
// Deprecated: Use [New] instead.
func NewFastestAddr() (f *FastestAddr) {
	// Delegate to New so that the defaults are defined in a single place and
	// the two constructors can't drift apart.
	return New(&Config{})
}

// Config contains the configuration for a *FastestAddr created with [New].
type Config struct {
	// Logger is used as the base logger for the service.  If nil,
	// [slog.Default] with [LogPrefix] is used.
	Logger *slog.Logger

	// PingWaitTimeout is the timeout for waiting all the resolved addresses to
	// be pinged.  Any ping results received after that moment are cached, but
	// won't be used.  If zero or negative, [DefaultPingWaitTimeout] is used.
	PingWaitTimeout time.Duration
}

// New initializes a new instance of *FastestAddr.  c may be nil, in which case
// the defaults are used, just like for an empty *Config.
func New(c *Config) (f *FastestAddr) {
	if c == nil {
		c = &Config{}
	}

	f = &FastestAddr{
		logger:      c.Logger,
		ipCacheLock: &sync.Mutex{},
		ipCache: cache.New(cache.Config{
			MaxSize:   64 * 1024,
			EnableLRU: true,
		}),
		pingPorts:       []uint{80, 443},
		pingWaitTimeout: DefaultPingWaitTimeout,
		pinger:          &net.Dialer{Timeout: pingTCPTimeout},
	}

	if c.PingWaitTimeout > 0 {
		f.pingWaitTimeout = c.PingWaitTimeout
	}

	if f.logger == nil {
		f.logger = slog.Default().With(slogutil.KeyPrefix, LogPrefix)
	}

	return f
}

// ExchangeFastest queries each specified upstream and returns the response with
// the fastest IP address.  The fastest IP address is considered to be the first
// one successfully dialed and other addresses are removed from the answer.  If
// no fastest address is found, the first response is returned unchanged.
func (f *FastestAddr) ExchangeFastest(
	req *dns.Msg,
	ups []upstream.Upstream,
) (resp *dns.Msg, u upstream.Upstream, err error) {
	replies, err := upstream.ExchangeAll(ups, req)
	if err != nil {
		return nil, nil, err
	}

	ipSet := container.NewMapSet[netip.Addr]()
	for _, r := range replies {
		for _, rr := range r.Resp.Answer {
			ip := ipFromRR(rr)
			if ip.IsValid() && !ip.IsUnspecified() {
				ipSet.Add(ip)
			}
		}
	}

	// This is an exported method, so don't assume the request has a question,
	// since indexing an empty section would panic.  The host is only passed to
	// pingAll and used in the log message below.
	var host string
	if len(req.Question) > 0 {
		host = strings.ToLower(req.Question[0].Name)
	}

	ips := ipSet.Values()
	if pingRes := f.pingAll(host, ips); pingRes != nil {
		return f.prepareReply(pingRes, replies)
	}

	f.logger.Debug("no fastest ip found, using the first response", "host", host)

	return replies[0].Resp, replies[0].Upstream, nil
}

// prepareReply picks the first reply whose answer contains the address from res
// and strips every other A and AAAA record from it.  The returned upstream is
// the one that sent that reply.  If no reply contains the address, an error is
// logged and the first reply is returned as is.
func (f *FastestAddr) prepareReply(
	res *pingResult,
	replies []upstream.ExchangeAllResult,
) (resp *dns.Msg, u upstream.Upstream, err error) {
	fastest := res.addrPort.Addr()

	for _, r := range replies {
		if !hasInAns(r.Resp, fastest) {
			continue
		}

		filterResponseAnswer(r.Resp, fastest)

		return r.Resp, r.Upstream, nil
	}

	f.logger.Error("found no replies, most likely this is a bug", "ip", fastest)

	// TODO(d.kolyshev): Consider returning error?
	return replies[0].Resp, replies[0].Upstream, nil
}

// filterResponseAnswer modifies resp in place, dropping every A and AAAA record
// whose address isn't equal to ip.  Records of all other types are kept.
func filterResponseAnswer(resp *dns.Msg, ip netip.Addr) {
	target := net.IP(ip.AsSlice())

	filtered := make([]dns.RR, 0, len(resp.Answer))
	for _, rr := range resp.Answer {
		var rrIP net.IP
		switch v := rr.(type) {
		case *dns.A:
			rrIP = v.A
		case *dns.AAAA:
			rrIP = v.AAAA
		default:
			filtered = append(filtered, rr)

			continue
		}

		if rrIP.Equal(target) {
			filtered = append(filtered, rr)
		}
	}

	resp.Answer = filtered
}

// hasInAns reports whether ip is the address of any record in the Answer
// section of m, as extracted by ipFromRR.
func hasInAns(m *dns.Msg, ip netip.Addr) (ok bool) {
	for i := range m.Answer {
		if ipFromRR(m.Answer[i]) == ip {
			return true
		}
	}

	return false
}

// ipFromRR returns the address held by rr if it's an A or AAAA record.  For any
// other record type, it returns the zero netip.Addr.
func ipFromRR(rr dns.RR) (ip netip.Addr) {
	if a, isA := rr.(*dns.A); isA {
		// The conversion error is deliberately ignored; callers such as
		// ExchangeFastest check the result with IsValid.
		ip, _ = netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)

		return ip
	}

	if aaaa, isAAAA := rr.(*dns.AAAA); isAAAA {
		ip, _ = netutil.IPToAddr(aaaa.AAAA, netutil.AddrFamilyIPv6)

		return ip
	}

	return ip
}
07070100000016000081A4000000000000000000000001679A649F000010DB000000000000000000000000000000000000002700000000dnsproxy-0.75.0/fastip/fastest_test.gopackage fastip

import (
	"net/netip"
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestFastestAddr_ExchangeFastest checks the error propagation, the choice of
// the alive address over a dead one, and the fallback to the first response
// when no address can be dialed.
func TestFastestAddr_ExchangeFastest(t *testing.T) {
	l := slogutil.NewDiscardLogger()

	t.Run("error", func(t *testing.T) {
		const errDesired errors.Error = "this is expected"

		u := &errUpstream{
			err: errDesired,
		}
		f := New(&Config{
			Logger:          l,
			PingWaitTimeout: DefaultPingWaitTimeout,
		})

		resp, up, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{u})
		require.Error(t, err)

		assert.ErrorIs(t, err, errDesired)
		assert.Nil(t, resp)
		assert.Nil(t, up)
	})

	t.Run("one_dead", func(t *testing.T) {
		port := listen(t, netip.IPv4Unspecified())

		f := New(&Config{
			Logger:          l,
			PingWaitTimeout: DefaultPingWaitTimeout,
		})
		f.pingPorts = []uint{port}

		// The alive IP is the just created local listener's address.  The dead
		// one is known as TEST-NET-1 which shouldn't be routed at all.  See
		// RFC-5737 (https://datatracker.ietf.org/doc/html/rfc5737).
		aliveAddr := netip.MustParseAddr("127.0.0.1")

		alive := &testAUpstream{
			recs: []*dns.A{newTestRec(t, aliveAddr)},
		}
		dead := &testAUpstream{
			recs: []*dns.A{newTestRec(t, netip.MustParseAddr("192.0.2.1"))},
		}

		rep, ups, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{dead, alive})
		require.NoError(t, err)

		// The expected value goes first, as testify expects.
		assert.Equal(t, alive, ups)

		require.NotNil(t, rep)
		require.NotEmpty(t, rep.Answer)
		require.IsType(t, new(dns.A), rep.Answer[0])

		ip := rep.Answer[0].(*dns.A).A
		assert.Equal(t, aliveAddr.AsSlice(), []byte(ip))
	})

	t.Run("all_dead", func(t *testing.T) {
		f := New(&Config{
			Logger:          l,
			PingWaitTimeout: DefaultPingWaitTimeout,
		})
		f.pingPorts = []uint{getFreePort(t)}

		firstIP := netip.MustParseAddr("127.0.0.1")
		ups := &testAUpstream{
			recs: []*dns.A{
				newTestRec(t, firstIP),
				newTestRec(t, netip.MustParseAddr("127.0.0.2")),
				newTestRec(t, netip.MustParseAddr("127.0.0.3")),
			},
		}

		resp, _, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{ups})
		require.NoError(t, err)

		require.NotNil(t, resp)
		require.NotEmpty(t, resp.Answer)
		require.IsType(t, new(dns.A), resp.Answer[0])

		ip := resp.Answer[0].(*dns.A).A
		assert.Equal(t, firstIP.AsSlice(), []byte(ip))
	})
}

// testAUpstream is a mock err upstream structure for tests.
type errUpstream struct {
	err      error
	closeErr error
}

// Address implements the [upstream.Upstream] interface for *errUpstream.  It
// always returns the same placeholder address.
func (u *errUpstream) Address() (addr string) {
	addr = "bad_upstream"

	return addr
}

// Exchange implements the [upstream.Upstream] interface for *errUpstream.  It
// ignores the request and always fails with the configured error.
func (u *errUpstream) Exchange(_ *dns.Msg) (resp *dns.Msg, err error) {
	err = u.err

	return nil, err
}

// Close implements the [upstream.Upstream] interface for *errUpstream.  It
// returns the configured close error, which may be nil.
func (u *errUpstream) Close() (err error) {
	err = u.closeErr

	return err
}

// testAUpstream is a mock A upstream structure for tests.  Its Exchange method
// replies to every request with all the records from recs.
type testAUpstream struct {
	// recs are the A records put into the answer section of every response.
	recs []*dns.A
}

// type check
var _ upstream.Upstream = (*testAUpstream)(nil)

// Exchange implements the [upstream.Upstream] interface for *testAUpstream.  It
// builds a reply to m and puts all the configured A records into its answer
// section.
func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	resp = new(dns.Msg).SetReply(m)
	for i := range u.recs {
		resp.Answer = append(resp.Answer, u.recs[i])
	}

	return resp, nil
}

// Address implements the [upstream.Upstream] interface for *testAUpstream.  The
// mock has no real address, so the result is always empty.
func (u *testAUpstream) Address() (addr string) {
	addr = ""

	return addr
}

// Close implements the [upstream.Upstream] interface for *testAUpstream.  There
// is nothing to release, so it always returns nil.
func (u *testAUpstream) Close() (err error) {
	return err
}

// newTestRec returns a new test A record for addr.  Its owner name is derived
// from the name of t, the same way newTestReq derives the question name, and
// its class is IN to match that question.
func newTestRec(t *testing.T, addr netip.Addr) (rr *dns.A) {
	return &dns.A{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(t.Name()),
			Rrtype: dns.TypeA,
			// Set the class explicitly, since the zero value is CLASS0, which
			// doesn't match the INET class of the test requests.
			Class: dns.ClassINET,
			Ttl:   60,
		},
		A: addr.AsSlice(),
	}
}

// newTestReq returns a new recursive test request for the A records of a
// domain name derived from the name of t.  The request has a random ID.
func newTestReq(t *testing.T) (req *dns.Msg) {
	q := dns.Question{
		Name:   dns.Fqdn(t.Name()),
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	req = &dns.Msg{
		Question: []dns.Question{q},
	}
	req.Id = dns.Id()
	req.RecursionDesired = true

	return req
}
07070100000017000081A4000000000000000000000001679A649F00000F48000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/fastip/ping.gopackage fastip

import (
	"net/netip"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
)

// pingTCPTimeout is a TCP connection timeout for pings.  It's intentionally
// higher than the ping wait timeout, since the results of the slower
// connections are still cached after the wait is over.
const pingTCPTimeout = 4 * time.Second

// pingResult is the result of dialing the address.
type pingResult struct {
	// addrPort is the address-port pair the result is related to.  The port
	// is zero for results that weren't produced by an actual dial, e.g. the
	// ones built from the cache or for a single address.
	addrPort netip.AddrPort

	// latency is the duration of dialing process in milliseconds.
	latency uint

	// success is true when the dialing succeeded.
	success bool
}

// schedulePings looks each of ips up in the cache.  For every address missing
// from the cache, it starts one pingDoTCP goroutine per configured ping port.
// Those goroutines send their results into resCh, and scheduled is set to true.
// pr is the cached address with zero status and the lowest latency, or nil if
// there is no such address.
func (f *FastestAddr) schedulePings(
	resCh chan *pingResult,
	ips []netip.Addr,
	host string,
) (pr *pingResult, scheduled bool) {
	for _, ip := range ips {
		ce := f.cacheFind(ip)
		if ce != nil {
			isBetter := pr == nil || ce.latencyMsec < pr.latency
			if ce.status == 0 && isBetter {
				pr = &pingResult{
					addrPort: netip.AddrPortFrom(ip, 0),
					latency:  ce.latencyMsec,
					success:  true,
				}
			}

			continue
		}

		scheduled = true
		for _, port := range f.pingPorts {
			addrPort := netip.AddrPortFrom(ip, uint16(port))
			go f.pingDoTCP(host, addrPort, resCh)
		}
	}

	return pr, scheduled
}

// pingAll concurrently pings all of ips and returns as soon as the fastest
// address is known or the wait timeout is exceeded.  It returns nil for no
// addresses, and a successful result with a zero port for a single address
// without pinging it.  A cached result is preferred only when it's strictly
// faster than the first successful ping.
func (f *FastestAddr) pingAll(host string, ips []netip.Addr) (pr *pingResult) {
	n := len(ips)
	if n == 0 {
		return nil
	}

	if n == 1 {
		return &pingResult{
			addrPort: netip.AddrPortFrom(ips[0], 0),
			success:  true,
		}
	}

	resCh := make(chan *pingResult, n*len(f.pingPorts))
	cached, scheduled := f.schedulePings(resCh, ips, host)
	if !scheduled {
		if cached == nil {
			f.logger.Debug("pinging all returns nothing", "host", host)
		} else {
			f.logger.Debug(
				"pinging all returns cached response",
				"host", host,
				"addr", cached.addrPort,
			)
		}

		return cached
	}

	fastest := f.firstSuccessRes(resCh, host)
	switch {
	case fastest == nil:
		// Timed out, so fall back to the cached result, which may be nil.
		return cached
	case cached == nil, fastest.latency <= cached.latency:
		// Nothing is cached or the cached result is no better.
		return fastest
	default:
		return cached
	}
}

// firstSuccessRes receives results from resCh until a successful one arrives,
// which is returned, or until f.pingWaitTimeout expires, in which case nil is
// returned.  Failed results are logged and skipped.
func (f *FastestAddr) firstSuccessRes(resCh chan *pingResult, host string) (res *pingResult) {
	timeout := time.After(f.pingWaitTimeout)
	for {
		select {
		case <-timeout:
			f.logger.Debug("pinging all timed out", "host", host)

			return nil
		case res = <-resCh:
			f.logger.Debug(
				"pinging all got result",
				"host", host,
				"addr", res.addrPort,
				"status", res.success,
			)

			if res.success {
				return res
			}
		}
	}
}

// pingDoTCP dials addrPort over TCP, sends the outcome into resCh, and then
// records it in the cache under the unmapped address.  A successfully opened
// connection is closed right away, since only the dial latency matters.
func (f *FastestAddr) pingDoTCP(host string, addrPort netip.AddrPort, resCh chan *pingResult) {
	l := f.logger.With("host", host, "addr", addrPort)
	l.Debug("open tcp connection")

	start := time.Now()
	conn, err := f.pinger.Dial(bootstrap.NetworkTCP, addrPort.String())
	elapsed := time.Since(start)
	latency := uint(elapsed.Milliseconds())

	ok := err == nil
	if ok {
		closeErr := conn.Close()
		if closeErr != nil {
			l.Debug("closing tcp connection", slogutil.KeyError, closeErr)
		}
	}

	resCh <- &pingResult{
		addrPort: addrPort,
		latency:  latency,
		success:  ok,
	}

	addr := addrPort.Addr().Unmap()
	if !ok {
		l.Debug("tcp ping failed to connect", "elapsed", elapsed, slogutil.KeyError, err)
		f.cacheAddFailure(addr)

		return
	}

	l.Debug("tcp ping success", "elapsed", elapsed)
	f.cacheAddSuccessful(addr, latency)
}
07070100000018000081A4000000000000000000000001679A649F00001610000000000000000000000000000000000000002400000000dnsproxy-0.75.0/fastip/ping_test.gopackage fastip

import (
	"net"
	"net/netip"
	"runtime"
	"sync"
	"syscall"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// unit is the convenient alias for struct{}.  In these tests, its values are
// used only as signals sent through channels.
type unit = struct{}

func TestFastestAddr_PingAll_timeout(t *testing.T) {
	t.Run("isolated", func(t *testing.T) {
		f := New(&Config{Logger: slogutil.NewDiscardLogger()})

		// Block all the dials until the end of the subtest.  There is a ping
		// per each pair of address and port, so close the channel instead of
		// sending a single value, which would only release one of them and
		// leak the rest.
		waitCh := make(chan unit)
		t.Cleanup(func() { close(waitCh) })

		f.pinger.Control = func(_, _ string, _ syscall.RawConn) error {
			<-waitCh

			return nil
		}

		ip := netutil.IPv4Localhost()
		res := f.pingAll("", []netip.Addr{ip, ip})
		require.Nil(t, res)
	})

	t.Run("cached", func(t *testing.T) {
		f := New(&Config{Logger: slogutil.NewDiscardLogger()})

		const lat uint = 42

		ip1 := netutil.IPv4Localhost()
		ip2 := netip.MustParseAddr("127.0.0.2")
		f.cacheAddSuccessful(ip1, lat)

		// Block the dials to ip2, one per each ping port, until the end of
		// the subtest, and release all of them at once by closing.
		waitCh := make(chan unit)
		t.Cleanup(func() { close(waitCh) })

		f.pinger.Control = func(_, _ string, _ syscall.RawConn) error {
			<-waitCh

			return nil
		}

		res := f.pingAll("", []netip.Addr{ip1, ip2})
		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, lat, res.latency)
	})
}

// assertCaching checks that the cache of f eventually contains an entry for ip
// with the specified status, polling it for up to pingTCPTimeout.
func assertCaching(t *testing.T, f *FastestAddr, ip netip.Addr, status int) {
	t.Helper()

	hasStatus := func() (ok bool) {
		if ce := f.cacheFind(ip); ce != nil {
			return ce.status == status
		}

		return false
	}

	assert.Eventually(t, hasStatus, pingTCPTimeout, pingTCPTimeout/16)
}

func TestFastestAddr_PingAll_cache(t *testing.T) {
	ip := netutil.IPv4Localhost()

	t.Run("cached_failed", func(t *testing.T) {
		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.cacheAddFailure(ip)

		res := f.pingAll("", []netip.Addr{ip, ip})
		require.Nil(t, res)
	})

	t.Run("cached_successful", func(t *testing.T) {
		const lat uint = 1

		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.cacheAddSuccessful(ip, lat)

		res := f.pingAll("", []netip.Addr{ip, ip})
		require.NotNil(t, res)
		assert.True(t, res.success)
		assert.Equal(t, lat, res.latency)
	})

	t.Run("not_cached", func(t *testing.T) {
		listener, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, listener.Close)

		f := New(&Config{Logger: slogutil.NewDiscardLogger()})

		f.pingPorts = []uint{uint(listener.Addr().(*net.TCPAddr).Port)}
		ips := []netip.Addr{ip, ip}

		wg := &sync.WaitGroup{}
		wg.Add(len(ips) * len(f.pingPorts))

		// Control is called from the pinging goroutines, so use assert instead
		// of require, since FailNow must only be called from the goroutine
		// running the test.  Always mark the dial as done, so that wg.Wait
		// can't hang on a failed check.
		f.pinger.Control = func(_, address string, _ syscall.RawConn) error {
			defer wg.Done()

			hostport, pErr := netutil.ParseHostPort(address)
			if !assert.NoError(t, pErr) {
				return pErr
			}

			assert.Equal(t, ip.String(), hostport.Host)
			assert.Contains(t, f.pingPorts, uint(hostport.Port))

			return nil
		}

		res := f.pingAll("", ips)
		require.NotNil(t, res)

		assert.True(t, res.success)
		assertCaching(t, f, ip, 0)

		wg.Wait()
	})
}

// listen starts a TCP listener on ip with a system-chosen port, arranges for
// it to be closed on the cleanup of t, and returns the chosen port.
func listen(t *testing.T, ip netip.Addr) (port uint) {
	t.Helper()

	addr := netip.AddrPortFrom(ip, 0).String()
	l, err := net.Listen("tcp", addr)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, l.Close)

	tcpAddr := l.Addr().(*net.TCPAddr)

	return uint(tcpAddr.Port)
}

func TestFastestAddr_PingAll(t *testing.T) {
	ip := netutil.IPv4Localhost()

	t.Run("single", func(t *testing.T) {
		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		res := f.pingAll("", []netip.Addr{ip})
		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, ip, res.addrPort.Addr())
		// There was no ping so the port is zero.
		assert.Zero(t, res.addrPort.Port())

		// Nothing in the cache since there was no ping.
		ce := f.cacheFind(res.addrPort.Addr())
		require.Nil(t, ce)
	})

	t.Run("fastest", func(t *testing.T) {
		fastPort := listen(t, ip)
		slowPort := listen(t, ip)

		// ctrlCh blocks the dials to the slow port.  There is one such dial
		// per each address, so the channel is closed to release all of them
		// at once instead of leaking all but one.
		ctrlCh := make(chan unit)

		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.pingPorts = []uint{
			fastPort,
			slowPort,
		}

		// Control is called from the pinging goroutines, so use assert
		// instead of require, since FailNow must only be called from the
		// goroutine running the test.
		f.pinger.Control = func(_, address string, _ syscall.RawConn) error {
			addrPort := netip.MustParseAddrPort(address)
			assert.Contains(t, []uint{fastPort, slowPort}, uint(addrPort.Port()))
			if addrPort.Port() == uint16(fastPort) {
				return nil
			}

			<-ctrlCh

			return nil
		}

		ips := []netip.Addr{ip, ip}
		res := f.pingAll("", ips)
		close(ctrlCh)

		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, ip, res.addrPort.Addr())
		assert.EqualValues(t, fastPort, res.addrPort.Port())

		assertCaching(t, f, ip, 0)
	})

	t.Run("zero", func(t *testing.T) {
		res := New(&Config{Logger: slogutil.NewDiscardLogger()}).pingAll("", nil)
		require.Nil(t, res)
	})

	t.Run("fail", func(t *testing.T) {
		port := getFreePort(t)

		f := New(&Config{Logger: slogutil.NewDiscardLogger()})
		f.pingPorts = []uint{port}

		res := f.pingAll("test", []netip.Addr{ip, ip})
		require.Nil(t, res)

		assertCaching(t, f, ip, 1)
	})
}

// getFreePort returns the number of a TCP port on 127.0.0.1 that had no
// listener at the moment of the call.
//
// TODO(e.burkov):  The logic is underwhelming.  Find a more accurate way.
func getFreePort(t *testing.T) (port uint) {
	t.Helper()

	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	tcpAddr := l.Addr().(*net.TCPAddr)

	// Release the port right away, so that nobody listens on it.
	require.NoError(t, l.Close())

	// Waiting for a while after closing may be necessary on Windows.
	if runtime.GOOS == "windows" {
		time.Sleep(100 * time.Millisecond)
	}

	return uint(tcpAddr.Port)
}
07070100000019000081A4000000000000000000000001679A649F000005C4000000000000000000000000000000000000001700000000dnsproxy-0.75.0/go.modmodule github.com/AdguardTeam/dnsproxy

go 1.23.5

require (
	github.com/AdguardTeam/golibs v0.31.0
	github.com/ameshkov/dnscrypt/v2 v2.3.0
	github.com/ameshkov/dnsstamps v1.0.3
	github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0
	github.com/bluele/gcache v0.0.2
	github.com/miekg/dns v1.1.62
	github.com/patrickmn/go-cache v2.1.0+incompatible
	github.com/quic-go/quic-go v0.48.2
	github.com/stretchr/testify v1.10.0
	golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67
	golang.org/x/net v0.33.0
	golang.org/x/sys v0.28.0
	gonum.org/v1/gonum v0.15.1
	gopkg.in/yaml.v3 v3.0.1
)

require (
	github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
	github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
	github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect
	github.com/kr/text v0.2.0 // indirect
	github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
	github.com/onsi/ginkgo/v2 v2.22.1 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	github.com/quic-go/qpack v0.5.1 // indirect
	go.uber.org/mock v0.5.0 // indirect
	golang.org/x/crypto v0.31.0 // indirect
	golang.org/x/mod v0.22.0 // indirect
	golang.org/x/sync v0.10.0 // indirect
	golang.org/x/text v0.21.0 // indirect
	golang.org/x/tools v0.28.0 // indirect
	gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
)
0707010000001A000081A4000000000000000000000001679A649F00001A2B000000000000000000000000000000000000001700000000dnsproxy-0.75.0/go.sumgithub.com/AdguardTeam/golibs v0.31.0 h1:Z0oPfLTLw6iZmpE58dePy2Bel0MaX+lnDwtFEE5EmIo=
github.com/AdguardTeam/golibs v0.31.0/go.mod h1:wIkZ9o2UnppeW6/YD7yJB71dYbMhiuC1Fh/I2ElW7GQ=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us=
github.com/ameshkov/dnscrypt/v2 v2.3.0 h1:pDXDF7eFa6Lw+04C0hoMh8kCAQM8NwUdFEllSP2zNLs=
github.com/ameshkov/dnscrypt/v2 v2.3.0/go.mod h1:N5hDwgx2cNb4Ay7AhvOSKst+eUiOZ/vbKRO9qMpQttE=
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
github.com/ameshkov/dnsstamps v1.0.3/go.mod h1:Ii3eUu73dx4Vw5O4wjzmT5+lkCwovjzaEZZ4gKyIH5A=
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 h1:0b2vaepXIfMsG++IsjHiI2p4bxALD1Y2nQKGMR5zDQM=
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA=
github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw=
github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg=
github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/onsi/ginkgo/v2 v2.22.1 h1:QW7tbJAUDyVDVOM5dFa7qaybo+CRfR7bemlQUN6Z8aM=
github.com/onsi/ginkgo/v2 v2.22.1/go.mod h1:S6aTpoRsSq2cZOd+pssHAlKW/Q/jZt6cPrPlnj4a1xM=
github.com/onsi/gomega v1.36.1 h1:bJDPBO7ibjxcbHMgSCoo4Yj18UWbKDlLwX1x9sybDcw=
github.com/onsi/gomega v1.36.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8=
golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw=
gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0=
gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
0707010000001B000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001900000000dnsproxy-0.75.0/internal0707010000001C000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002300000000dnsproxy-0.75.0/internal/bootstrap0707010000001D000081A4000000000000000000000001679A649F00000ECA000000000000000000000000000000000000003000000000dnsproxy-0.75.0/internal/bootstrap/bootstrap.go// Package bootstrap provides types and functions to resolve upstream hostnames
// and to dial retrieved addresses.
package bootstrap

import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/netip"
	"net/url"
	"slices"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
)

// Network is the type of network names accepted by the methods of [Resolver]
// and by [DialHandler].
type Network = string

const (
	// NetworkTCP is the network type for TCP connections.
	NetworkTCP Network = "tcp"

	// NetworkUDP is the network type for UDP connections.
	NetworkUDP Network = "udp"

	// NetworkIP is the network type for both the IPv4 and the IPv6 address
	// families.
	NetworkIP Network = "ip"

	// NetworkIP4 is the network type for the IPv4 address family only.
	NetworkIP4 Network = "ip4"

	// NetworkIP6 is the network type for the IPv6 address family only.
	NetworkIP6 Network = "ip6"
)

// DialHandler is a dial function that creates unencrypted network connections
// to the upstream server.  It connects to the server specified at its
// creation and ignores addr.  network must be either [NetworkTCP] or
// [NetworkUDP].
type DialHandler func(ctx context.Context, network Network, addr string) (conn net.Conn, err error)

// ResolveDialContext returns a DialHandler that dials the addresses resolved
// from the hostname of u using r, in the order produced by sorting them with
// [netutil.PreferIPv6] if preferV6 is true and with [netutil.PreferIPv4]
// otherwise.  Invalid addresses returned by r are skipped.  timeout limits the
// resolving, if positive, and is also passed to [NewDialContext].  l and u must
// not be nil.
func ResolveDialContext(
	u *url.URL,
	timeout time.Duration,
	r Resolver,
	preferV6 bool,
	l *slog.Logger,
) (h DialHandler, err error) {
	defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }()

	host, port, err := netutil.SplitHostPort(u.Host)
	if err != nil {
		// Don't wrap the error since it's informative enough as is and there is
		// already deferred annotation here.
		return nil, err
	}

	if r == nil {
		return nil, fmt.Errorf("resolver is nil: %w", ErrNoResolvers)
	}

	ctx := context.Background()
	if timeout > 0 {
		var cancel func()
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	// TODO(e.burkov):  Use network properly, perhaps, pass it through options.
	ips, err := r.LookupNetIP(ctx, NetworkIP, host)
	if err != nil {
		return nil, fmt.Errorf("resolving hostname: %w", err)
	}

	// Copy only the valid addresses into a new slice.  Dialing an invalid one
	// can never succeed, and sorting in place would modify the slice owned by
	// the resolver.
	valid := make([]netip.Addr, 0, len(ips))
	for _, ip := range ips {
		if ip.IsValid() {
			valid = append(valid, ip)
		}
	}

	if preferV6 {
		slices.SortStableFunc(valid, netutil.PreferIPv6)
	} else {
		slices.SortStableFunc(valid, netutil.PreferIPv4)
	}

	addrs := make([]string, 0, len(valid))
	for _, ip := range valid {
		addrs = append(addrs, netip.AddrPortFrom(ip, port).String())
	}

	return NewDialContext(timeout, l, addrs...), nil
}

// NewDialContext returns a DialHandler that dials addrs one by one, in order,
// using the network passed to it, and returns the first connection that
// succeeds.  The addr argument of the handler is ignored.  If every dial fails,
// the handler returns their errors joined.  If addrs is empty, the handler
// always fails.  timeout is used as the timeout of the underlying dialer.  l
// must not be nil.
func NewDialContext(timeout time.Duration, l *slog.Logger, addrs ...string) (h DialHandler) {
	total := len(addrs)
	if total == 0 {
		l.Debug("no addresses to dial")

		return func(_ context.Context, _, _ string) (conn net.Conn, err error) {
			return nil, errors.Error("no addresses")
		}
	}

	dialer := &net.Dialer{Timeout: timeout}

	h = func(ctx context.Context, network Network, _ string) (conn net.Conn, err error) {
		var dialErrs []error

		// Dial the addresses captured at creation, not the one passed to the
		// handler.
		for idx := range total {
			addr := addrs[idx]
			al := l.With("addr", addr)
			al.DebugContext(ctx, "dialing", "idx", idx+1, "total", total)

			started := time.Now()
			c, dialErr := dialer.DialContext(ctx, network, addr)
			took := time.Since(started)
			if dialErr == nil {
				al.DebugContext(ctx, "connection succeeded", "elapsed", took)

				return c, nil
			}

			al.DebugContext(ctx, "connection failed", "elapsed", took, slogutil.KeyError, dialErr)
			dialErrs = append(dialErrs, dialErr)
		}

		return nil, errors.Join(dialErrs...)
	}

	return h
}
0707010000001E000081A4000000000000000000000001679A649F0000113C000000000000000000000000000000000000003500000000dnsproxy-0.75.0/internal/bootstrap/bootstrap_test.gopackage bootstrap_test

import (
	"context"
	"net"
	"net/netip"
	"net/url"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testTimeout is a common timeout used in tests of this package, e.g. as the
// dialing timeout and for sending to and receiving from channels.
const testTimeout = 1 * time.Second

// newListener creates a new listener of the specified network type on a
// system-chosen port of 127.0.0.1, adds its closing to the test cleanup, and
// returns its address.  The local address of each accepted connection is sent
// into sig, so sig must be read after every connection.
func newListener(t testing.TB, network string, sig chan net.Addr) (ipp netip.AddrPort) {
	t.Helper()

	// TODO(e.burkov):  Listen IPv6 as well, when the CI adds IPv6 interfaces.
	l, err := net.Listen(network, "127.0.0.1:0")
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, l.Close)

	go func() {
		pt := testutil.PanicT{}
		for {
			c, acceptErr := l.Accept()
			if errors.Is(acceptErr, net.ErrClosed) {
				return
			}

			require.NoError(pt, acceptErr)
			testutil.RequireSend(pt, sig, c.LocalAddr(), testTimeout)
			require.NoError(pt, c.Close())
		}
	}()

	ipp, err = netip.ParseAddrPort(l.Addr().String())
	require.NoError(t, err)

	return ipp
}

// See the details here: https://github.com/AdguardTeam/dnsproxy/issues/18
func TestResolveDialContext(t *testing.T) {
	sig := make(chan net.Addr, 1)

	ipp := newListener(t, "tcp", sig)
	port := ipp.Port()

	l := slogutil.NewDiscardLogger()

	testCases := []struct {
		name       string
		addresses  []netip.Addr
		preferIPv6 bool
	}{{
		name:       "v4",
		addresses:  []netip.Addr{netutil.IPv4Localhost()},
		preferIPv6: false,
	}, {
		name:       "both_prefer_v6",
		addresses:  []netip.Addr{netutil.IPv4Localhost(), netutil.IPv6Localhost()},
		preferIPv6: true,
	}, {
		name:       "both_prefer_v4",
		addresses:  []netip.Addr{netutil.IPv6Localhost(), netutil.IPv4Localhost()},
		preferIPv6: false,
	}, {
		name:       "strip_invalid",
		addresses:  []netip.Addr{{}, netutil.IPv4Localhost(), {}, netutil.IPv6Localhost(), {}},
		preferIPv6: true,
	}}

	const hostname = "host.name"

	pt := testutil.PanicT{}

	for _, tc := range testCases {
		r := &testResolver{
			onLookupNetIP: func(
				_ context.Context,
				network string,
				host string,
			) (addrs []netip.Addr, err error) {
				require.Equal(pt, bootstrap.NetworkIP, network)
				require.Equal(pt, hostname, host)

				return tc.addresses, nil
			},
		}

		t.Run(tc.name, func(t *testing.T) {
			dialContext, err := bootstrap.ResolveDialContext(
				&url.URL{Host: netutil.JoinHostPort(hostname, port)},
				testTimeout,
				bootstrap.ParallelResolver{r},
				tc.preferIPv6,
				l,
			)
			require.NoError(t, err)

			conn, err := dialContext(context.Background(), bootstrap.NetworkTCP, "")
			require.NoError(t, err)
			// The listener only closes its own side of the connection, so close
			// the client side to not leak the socket.
			testutil.CleanupAndRequireSuccess(t, conn.Close)

			expected, ok := testutil.RequireReceive(t, sig, testTimeout)
			require.True(t, ok)

			assert.Equal(t, expected.String(), conn.RemoteAddr().String())
		})
	}

	t.Run("no_addresses", func(t *testing.T) {
		r := &testResolver{
			onLookupNetIP: func(
				_ context.Context,
				network string,
				host string,
			) (addrs []netip.Addr, err error) {
				require.Equal(pt, bootstrap.NetworkIP, network)
				require.Equal(pt, hostname, host)

				return nil, nil
			},
		}

		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: netutil.JoinHostPort(hostname, port)},
			testTimeout,
			bootstrap.ParallelResolver{r},
			false,
			l,
		)
		require.NoError(t, err)

		_, err = dialContext(context.Background(), bootstrap.NetworkTCP, "")
		testutil.AssertErrorMsg(t, "no addresses", err)
	})

	t.Run("bad_hostname", func(t *testing.T) {
		const errMsg = `dialing "bad hostname": address bad hostname: ` +
			`missing port in address`

		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: "bad hostname"},
			testTimeout,
			nil,
			false,
			l,
		)
		testutil.AssertErrorMsg(t, errMsg, err)

		assert.Nil(t, dialContext)
	})

	t.Run("no_resolvers", func(t *testing.T) {
		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: netutil.JoinHostPort(hostname, port)},
			testTimeout,
			nil,
			false,
			l,
		)
		assert.ErrorIs(t, err, bootstrap.ErrNoResolvers)
		assert.Nil(t, dialContext)
	})
}
0707010000001F000081A4000000000000000000000001679A649F000000BC000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/bootstrap/error.gopackage bootstrap

import "github.com/AdguardTeam/golibs/errors"

// ErrNoResolvers is returned, possibly wrapped, when no resolvers are
// specified, e.g. by an empty [ParallelResolver] or [ConsequentResolver], or by
// [ResolveDialContext] when its resolver is nil.
const ErrNoResolvers errors.Error = "no resolvers specified"
07070100000020000081A4000000000000000000000001679A649F00000FE3000000000000000000000000000000000000002F00000000dnsproxy-0.75.0/internal/bootstrap/resolver.gopackage bootstrap

import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/netip"
	"slices"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
)

// Resolver resolves the hostnames to IP addresses.  Note that [net.Resolver]
// from the standard library also implements this interface.
type Resolver interface {
	// LookupNetIP looks up the IP addresses for the given host.  network should
	// be one of [NetworkIP], [NetworkIP4] or [NetworkIP6].  The response may be
	// empty even if err is nil.  All the addrs must be valid.
	LookupNetIP(ctx context.Context, network Network, host string) (addrs []netip.Addr, err error)
}

// type check
var _ Resolver = &net.Resolver{}

// ParallelResolver is a slice of resolvers that are queried concurrently.  The
// first successful response is returned, even if it's empty.  When there are
// several resolvers, a panic in any of them is recovered and treated as that
// resolver's error.
type ParallelResolver []Resolver

// type check
var _ Resolver = ParallelResolver(nil)

// LookupNetIP implements the [Resolver] interface for ParallelResolver.  A
// single resolver is called directly.  Otherwise, all the resolvers are queried
// concurrently and the first successful result, even an empty one, is
// returned.  If all of them fail, their errors are joined.
func (r ParallelResolver) LookupNetIP(
	ctx context.Context,
	network Network,
	host string,
) (addrs []netip.Addr, err error) {
	n := len(r)
	if n == 0 {
		return nil, ErrNoResolvers
	}

	if n == 1 {
		return r[0].LookupNetIP(ctx, network, host)
	}

	// Buffer the channel for the results of all the resolvers, so that the
	// lookups finishing after the first success never block on sending.
	results := make(chan any, n)
	for _, res := range r {
		go lookupAsync(ctx, res, network, host, results)
	}

	var errs []error
	for received := 0; received < n; received++ {
		switch v := (<-results).(type) {
		case error:
			errs = append(errs, v)
		case []netip.Addr:
			return v, nil
		}
	}

	return nil, errors.Join(errs...)
}

// recoverAndLog is a deferred helper that recovers from a panic, if any.  It
// logs the panic value and the stack trace with the logger from ctx, or with
// [slog.Default] if ctx has none, and then sends the recovered value into resCh
// as an error.  It does nothing if there was no panic.  It must be deferred
// directly, since recover only stops a panic when called by the deferred
// function itself.
//
// TODO(a.garipov): Move this helper to golibs.
func recoverAndLog(ctx context.Context, resCh chan<- any) {
	v := recover()
	if v == nil {
		return
	}

	// Wrap non-error panic values, so that the receiver always gets an error.
	err, ok := v.(error)
	if !ok {
		err = fmt.Errorf("error value: %v", v)
	}

	l, ok := slogutil.LoggerFromContext(ctx)
	if !ok {
		l = slog.Default()
	}

	l.ErrorContext(ctx, "recovered panic", slogutil.KeyError, err)
	slogutil.PrintStack(ctx, l, slog.LevelError)

	resCh <- err
}

// lookupAsync looks up the addresses of host with r and sends exactly one
// value into resCh: the error if the lookup failed, or the addresses
// otherwise.  It is intended to be used as a goroutine.  A panic in r is
// recovered and sent into resCh as an error.
func lookupAsync(ctx context.Context, r Resolver, network, host string, resCh chan<- any) {
	// TODO(d.kolyshev): Propose better solution to recover without requiring
	// logger in the context.
	defer recoverAndLog(ctx, resCh)

	addrs, err := r.LookupNetIP(ctx, network, host)
	if err != nil {
		resCh <- err

		return
	}

	resCh <- addrs
}

// ConsequentResolver is a slice of resolvers that are queried in order until
// the first successful non-empty response, as opposed to the mere success
// required by [ParallelResolver].
type ConsequentResolver []Resolver

// type check
var _ Resolver = ConsequentResolver(nil)

// LookupNetIP implements the [Resolver] interface for ConsequentResolver.  It
// calls the resolvers in order and returns the first non-empty result obtained
// without an error.  If there is no such result, nil addresses are returned
// along with the errors of all the resolvers joined.
func (resolvers ConsequentResolver) LookupNetIP(
	ctx context.Context,
	network Network,
	host string,
) (addrs []netip.Addr, err error) {
	if len(resolvers) == 0 {
		return nil, ErrNoResolvers
	}

	errs := make([]error, 0, len(resolvers))
	for i := range resolvers {
		found, lookupErr := resolvers[i].LookupNetIP(ctx, network, host)
		if lookupErr == nil && len(found) > 0 {
			return found, nil
		}

		errs = append(errs, lookupErr)
	}

	return nil, errors.Join(errs...)
}

// StaticResolver is a resolver which always responds with a copy of its
// underlying slice of IP addresses, regardless of the host and network.
type StaticResolver []netip.Addr

// type check
var _ Resolver = StaticResolver(nil)

// LookupNetIP implements the [Resolver] interface for StaticResolver.
func (r StaticResolver) LookupNetIP(
	_ context.Context,
	_ Network,
	_ string,
) (addrs []netip.Addr, err error) {
	return slices.Clone(r), nil
}
07070100000021000081A4000000000000000000000001679A649F00000A9A000000000000000000000000000000000000003400000000dnsproxy-0.75.0/internal/bootstrap/resolver_test.gopackage bootstrap_test

import (
	"context"
	"net/netip"
	"strings"
	"testing"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testResolver is a [Resolver] implementation for tests.  It forwards every
// lookup to its onLookupNetIP callback.
type testResolver struct {
	onLookupNetIP func(ctx context.Context, network, host string) (addrs []netip.Addr, err error)
}

// LookupNetIP implements the [Resolver] interface for *testResolver.  It calls
// r.onLookupNetIP with the same arguments and returns its results unchanged.
// onLookupNetIP must not be nil.
func (r *testResolver) LookupNetIP(
	ctx context.Context,
	network string,
	host string,
) (addrs []netip.Addr, err error) {
	lookup := r.onLookupNetIP
	addrs, err = lookup(ctx, network, host)

	return addrs, err
}

// TestLookupParallel checks [bootstrap.ParallelResolver] in four cases:
//   - no resolvers: returns [bootstrap.ErrNoResolvers];
//   - one resolver: returns that resolver's answer;
//   - two resolvers: the immediate answer wins over a delayed one;
//   - every resolver fails: the errors are joined, newline-separated.
func TestLookupParallel(t *testing.T) {
	const hostname = "host.name"

	t.Run("no_resolvers", func(t *testing.T) {
		addrs, err := bootstrap.ParallelResolver(nil).LookupNetIP(context.Background(), "ip", "")
		assert.ErrorIs(t, err, bootstrap.ErrNoResolvers)
		assert.Nil(t, addrs)
	})

	// pt is used for assertions inside the resolver callbacks.  Those may run
	// outside the test's own goroutine, which is presumably why a panicking
	// TestingT is used there instead of t.
	pt := testutil.PanicT{}
	hostAddrs := []netip.Addr{netutil.IPv4Localhost()}

	// immediate answers with hostAddrs right away after checking its
	// arguments.
	immediate := &testResolver{
		onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
			require.Equal(pt, hostname, host)
			require.Equal(pt, "ip", network)

			return hostAddrs, nil
		},
	}

	t.Run("one_resolver", func(t *testing.T) {
		addrs, err := bootstrap.ParallelResolver{immediate}.LookupNetIP(
			context.Background(),
			"ip",
			hostname,
		)
		require.NoError(t, err)

		assert.Equal(t, hostAddrs, addrs)
	})

	t.Run("two_resolvers", func(t *testing.T) {
		// delayed blocks until the test sends into delayCh, which happens
		// only after LookupNetIP has returned, so the returned addresses must
		// come from immediate.  The channel has a buffer of one, so the send
		// can complete even if delayed hasn't started receiving yet.
		delayCh := make(chan struct{}, 1)
		delayed := &testResolver{
			onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
				require.Equal(pt, hostname, host)
				require.Equal(pt, "ip", network)

				testutil.RequireReceive(pt, delayCh, testTimeout)

				return []netip.Addr{netutil.IPv6Localhost()}, nil
			},
		}

		addrs, err := bootstrap.ParallelResolver{immediate, delayed}.LookupNetIP(
			context.Background(),
			"ip",
			hostname,
		)
		require.NoError(t, err)
		testutil.RequireSend(t, delayCh, struct{}{}, testTimeout)

		assert.Equal(t, hostAddrs, addrs)
	})

	t.Run("all_errors", func(t *testing.T) {
		// Three failing resolvers are expected to produce three copies of the
		// error message joined with newlines.
		err := assert.AnError
		errStr := err.Error()
		wantErrMsg := strings.Join([]string{errStr, errStr, errStr}, "\n")

		r := &testResolver{
			onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
				return nil, assert.AnError
			},
		}

		addrs, err := bootstrap.ParallelResolver{r, r, r}.LookupNetIP(
			context.Background(),
			"ip",
			hostname,
		)
		testutil.AssertErrorMsg(t, wantErrMsg, err)
		assert.Nil(t, addrs)
	})
}
07070100000022000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001D00000000dnsproxy-0.75.0/internal/cmd07070100000023000081A4000000000000000000000001679A649F000041D2000000000000000000000000000000000000002500000000dnsproxy-0.75.0/internal/cmd/args.gopackage cmd

import (
	"flag"
	"fmt"
	"io"
	"os"
	"slices"
	"strings"

	"github.com/AdguardTeam/dnsproxy/internal/version"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/osutil"
	"github.com/AdguardTeam/golibs/timeutil"
)

// Indexes to help with the [commandLineOptions] initialization.  Each constant
// keys both an entry of commandLineOptions and the matching configuration
// field pointer in parseCmdLineOptions.  Adding a new option therefore means
// adding a constant here and an entry in both of those places.
const (
	configPathIdx = iota
	logOutputIdx
	tlsCertPathIdx
	tlsKeyPathIdx
	httpsServerNameIdx
	httpsUserinfoIdx
	dnsCryptConfigPathIdx
	ednsAddrIdx
	upstreamModeIdx
	listenAddrsIdx
	listenPortsIdx
	httpsListenPortsIdx
	tlsListenPortsIdx
	quicListenPortsIdx
	dnsCryptListenPortsIdx
	upstreamsIdx
	bootstrapDNSIdx
	fallbacksIdx
	privateRDNSUpstreamsIdx
	dns64PrefixIdx
	privateSubnetsIdx
	bogusNXDomainIdx
	hostsFilesIdx
	timeoutIdx
	cacheMinTTLIdx
	cacheMaxTTLIdx
	cacheSizeBytesIdx
	ratelimitIdx
	ratelimitSubnetLenIPv4Idx
	ratelimitSubnetLenIPv6Idx
	udpBufferSizeIdx
	maxGoRoutinesIdx
	tlsMinVersionIdx
	tlsMaxVersionIdx
	helpIdx
	hostsFileEnabledIdx
	pprofIdx
	versionIdx
	verboseIdx
	insecureIdx
	ipv6DisabledIdx
	http3Idx
	cacheOptimisticIdx
	cacheIdx
	refuseAnyIdx
	enableEDNSSubnetIdx
	dns64Idx
	usePrivateRDNSIdx
)

// commandLineOption contains information about a command-line option: its long
// and, if there is one, short forms, the value type, and the description.  It
// is used both to register the flags and to render the usage message.
type commandLineOption struct {
	description string // Help text printed by usage.
	long        string // Long name, shown as --long in usage.
	short       string // Optional short name, shown as -short; empty if none.
	valueType   string // Value hint for usage; empty means no hint is shown.
}

// commandLineOptions are all command-line options currently supported by the
// binary.  Entries are keyed by the *Idx constants, which parseCmdLineOptions
// also uses to pair each option with its configuration field.
//
// An empty short means the option has no short form.  An empty valueType means
// the option takes no value, such as a boolean switch, and is rendered without
// a value hint by writeUsageLine.  Options that take a value must therefore
// have a non-empty valueType.
var commandLineOptions = []*commandLineOption{
	configPathIdx: {
		description: "YAML configuration file. Minimal working configuration in config.yaml.dist." +
			" Options passed through command line will override the ones from this file.",
		long:      "config-path",
		short:     "",
		valueType: "path",
	},
	logOutputIdx: {
		description: `Path to the log file.`,
		long:        "output",
		short:       "o",
		valueType:   "path",
	},
	tlsCertPathIdx: {
		description: "Path to a file with the certificate chain.",
		long:        "tls-crt",
		short:       "c",
		valueType:   "path",
	},
	tlsKeyPathIdx: {
		description: "Path to a file with the private key.",
		long:        "tls-key",
		short:       "k",
		valueType:   "path",
	},
	httpsServerNameIdx: {
		description: "Set the Server header for the responses from the HTTPS server.",
		long:        "https-server-name",
		short:       "",
		valueType:   "name",
	},
	httpsUserinfoIdx: {
		description: "If set, all DoH queries are required to have this basic authentication " +
			"information.",
		long:      "https-userinfo",
		short:     "",
		valueType: "name",
	},
	dnsCryptConfigPathIdx: {
		description: "Path to a file with DNSCrypt configuration. You can generate one using " +
			"https://github.com/ameshkov/dnscrypt.",
		long:      "dnscrypt-config",
		short:     "g",
		valueType: "path",
	},
	ednsAddrIdx: {
		description: "Send EDNS Client Address.",
		long:        "edns-addr",
		short:       "",
		valueType:   "address",
	},
	upstreamModeIdx: {
		description: "Defines the upstreams logic mode, possible values: load_balance, parallel, " +
			"fastest_addr (default: load_balance).",
		long:      "upstream-mode",
		short:     "",
		valueType: "mode",
	},
	listenAddrsIdx: {
		description: "Listening addresses.",
		long:        "listen",
		short:       "l",
		valueType:   "address",
	},
	listenPortsIdx: {
		description: "Listening ports. Zero value disables TCP and UDP listeners.",
		long:        "port",
		short:       "p",
		valueType:   "port",
	},
	httpsListenPortsIdx: {
		description: "Listening ports for DNS-over-HTTPS.",
		long:        "https-port",
		short:       "s",
		valueType:   "port",
	},
	tlsListenPortsIdx: {
		description: "Listening ports for DNS-over-TLS.",
		long:        "tls-port",
		short:       "t",
		valueType:   "port",
	},
	quicListenPortsIdx: {
		description: "Listening ports for DNS-over-QUIC.",
		long:        "quic-port",
		short:       "q",
		valueType:   "port",
	},
	dnsCryptListenPortsIdx: {
		description: "Listening ports for DNSCrypt.",
		long:        "dnscrypt-port",
		short:       "y",
		valueType:   "port",
	},
	upstreamsIdx: {
		description: "An upstream to be used (can be specified multiple times). You can also " +
			"specify path to a file with the list of servers.",
		long:  "upstream",
		short: "u",
		// This option takes a value, so it needs a value hint in usage.
		valueType: "address",
	},
	bootstrapDNSIdx: {
		description: "Bootstrap DNS for DoH and DoT, can be specified multiple times (default: " +
			"use system-provided).",
		long:      "bootstrap",
		short:     "b",
		valueType: "address",
	},
	fallbacksIdx: {
		description: "Fallback resolvers to use when regular ones are unavailable, can be " +
			"specified multiple times. You can also specify path to a file with the list of servers.",
		long:      "fallback",
		short:     "f",
		valueType: "address",
	},
	privateRDNSUpstreamsIdx: {
		description: "Private DNS upstreams to use for reverse DNS lookups of private addresses, " +
			"can be specified multiple times.",
		long:      "private-rdns-upstream",
		short:     "",
		valueType: "address",
	},
	dns64PrefixIdx: {
		description: "Prefix used to handle DNS64. If not specified, dnsproxy uses the " +
			"'Well-Known Prefix' 64:ff9b::.  Can be specified multiple times.",
		long:      "dns64-prefix",
		short:     "",
		valueType: "subnet",
	},
	privateSubnetsIdx: {
		description: "Private subnets to use for reverse DNS lookups of private addresses.",
		long:        "private-subnets",
		short:       "",
		valueType:   "subnet",
	},
	bogusNXDomainIdx: {
		description: "Transform the responses containing at least a single IP that matches " +
			"specified addresses and CIDRs into NXDOMAIN.  Can be specified multiple times.",
		long:      "bogus-nxdomain",
		short:     "",
		valueType: "subnet",
	},
	hostsFilesIdx: {
		description: "List of paths to the hosts files, can be specified multiple times.",
		long:        "hosts-files",
		short:       "",
		valueType:   "path",
	},
	timeoutIdx: {
		description: "Timeout for outbound DNS queries to remote upstream servers in a " +
			"human-readable form.",
		long:      "timeout",
		short:     "",
		valueType: "duration",
	},
	cacheMinTTLIdx: {
		description: "Minimum TTL value for DNS entries, in seconds. Capped at 3600. " +
			"Artificially extending TTLs should only be done with careful consideration.",
		long:      "cache-min-ttl",
		short:     "",
		valueType: "uint32",
	},
	cacheMaxTTLIdx: {
		description: "Maximum TTL value for DNS entries, in seconds.",
		long:        "cache-max-ttl",
		short:       "",
		valueType:   "uint32",
	},
	cacheSizeBytesIdx: {
		description: "Cache size (in bytes). Default: 64k.",
		long:        "cache-size",
		short:       "",
		valueType:   "int",
	},
	ratelimitIdx: {
		description: "Ratelimit (requests per second).",
		long:        "ratelimit",
		short:       "r",
		valueType:   "int",
	},
	ratelimitSubnetLenIPv4Idx: {
		description: "Ratelimit subnet length for IPv4.",
		long:        "ratelimit-subnet-len-ipv4",
		short:       "",
		valueType:   "int",
	},
	ratelimitSubnetLenIPv6Idx: {
		description: "Ratelimit subnet length for IPv6.",
		long:        "ratelimit-subnet-len-ipv6",
		short:       "",
		valueType:   "int",
	},
	udpBufferSizeIdx: {
		description: "Set the size of the UDP buffer in bytes. A value <= 0 will use the system " +
			"default.",
		long:      "udp-buf-size",
		short:     "",
		valueType: "int",
	},
	maxGoRoutinesIdx: {
		description: "Set the maximum number of goroutines. A zero value will not set a " +
			"maximum.",
		long:      "max-go-routines",
		short:     "",
		valueType: "uint",
	},
	tlsMinVersionIdx: {
		description: "Minimum TLS version, for example 1.0.",
		long:        "tls-min-version",
		short:       "",
		valueType:   "version",
	},
	tlsMaxVersionIdx: {
		description: "Maximum TLS version, for example 1.3.",
		long:        "tls-max-version",
		short:       "",
		valueType:   "version",
	},
	helpIdx: {
		description: "Print this help message and quit.",
		long:        "help",
		short:       "h",
		valueType:   "",
	},
	hostsFileEnabledIdx: {
		description: "If specified, use hosts files for resolving.",
		long:        "hosts-file-enabled",
		short:       "",
		valueType:   "",
	},
	pprofIdx: {
		description: "If present, exposes pprof information on localhost:6060.",
		long:        "pprof",
		short:       "",
		valueType:   "",
	},
	versionIdx: {
		description: "Prints the program version.",
		long:        "version",
		short:       "",
		valueType:   "",
	},
	verboseIdx: {
		description: "Verbose output.",
		long:        "verbose",
		short:       "v",
		valueType:   "",
	},
	insecureIdx: {
		description: "Disable secure TLS certificate validation.",
		long:        "insecure",
		short:       "",
		valueType:   "",
	},
	ipv6DisabledIdx: {
		description: "If specified, all AAAA requests will be replied with NoError RCode and " +
			"empty answer.",
		long:      "ipv6-disabled",
		short:     "",
		valueType: "",
	},
	http3Idx: {
		description: "Enable HTTP/3 support.",
		long:        "http3",
		short:       "",
		valueType:   "",
	},
	cacheOptimisticIdx: {
		description: "If specified, optimistic DNS cache is enabled.",
		long:        "cache-optimistic",
		short:       "",
		valueType:   "",
	},
	cacheIdx: {
		description: "If specified, DNS cache is enabled.",
		long:        "cache",
		short:       "",
		valueType:   "",
	},
	refuseAnyIdx: {
		description: "If specified, refuses ANY requests.",
		long:        "refuse-any",
		short:       "",
		valueType:   "",
	},
	enableEDNSSubnetIdx: {
		description: "Use EDNS Client Subnet extension.",
		long:        "edns",
		short:       "",
		valueType:   "",
	},
	dns64Idx: {
		description: "If specified, dnsproxy will act as a DNS64 server.",
		long:        "dns64",
		short:       "",
		valueType:   "",
	},
	usePrivateRDNSIdx: {
		description: "If specified, use private upstreams for reverse DNS lookups of private " +
			"addresses.",
		long:      "use-private-rdns",
		short:     "",
		valueType: "",
	},
}

// parseCmdLineOptions parses os.Args into conf.  Options are bound directly to
// the fields of conf, so the values already in conf serve as the defaults.  On
// a flag parse failure the flag set prints the usage message to stderr, and
// the error is returned unwrapped.  Positional arguments are rejected with an
// error.  conf must not be nil.
//
// NOTE(review): whether repeated slice options append to or replace values
// already in conf depends on the slice flag value implementations.
func parseCmdLineOptions(conf *configuration) (err error) {
	cmdName := os.Args[0]
	args := os.Args[1:]

	// The pointers are keyed by the same constants as commandLineOptions, so
	// fieldPtrs[i] is described by commandLineOptions[i].
	fieldPtrs := []any{
		configPathIdx:             &conf.ConfigPath,
		logOutputIdx:              &conf.LogOutput,
		tlsCertPathIdx:            &conf.TLSCertPath,
		tlsKeyPathIdx:             &conf.TLSKeyPath,
		httpsServerNameIdx:        &conf.HTTPSServerName,
		httpsUserinfoIdx:          &conf.HTTPSUserinfo,
		dnsCryptConfigPathIdx:     &conf.DNSCryptConfigPath,
		ednsAddrIdx:               &conf.EDNSAddr,
		upstreamModeIdx:           &conf.UpstreamMode,
		listenAddrsIdx:            &conf.ListenAddrs,
		listenPortsIdx:            &conf.ListenPorts,
		httpsListenPortsIdx:       &conf.HTTPSListenPorts,
		tlsListenPortsIdx:         &conf.TLSListenPorts,
		quicListenPortsIdx:        &conf.QUICListenPorts,
		dnsCryptListenPortsIdx:    &conf.DNSCryptListenPorts,
		upstreamsIdx:              &conf.Upstreams,
		bootstrapDNSIdx:           &conf.BootstrapDNS,
		fallbacksIdx:              &conf.Fallbacks,
		privateRDNSUpstreamsIdx:   &conf.PrivateRDNSUpstreams,
		dns64PrefixIdx:            &conf.DNS64Prefix,
		privateSubnetsIdx:         &conf.PrivateSubnets,
		bogusNXDomainIdx:          &conf.BogusNXDomain,
		hostsFilesIdx:             &conf.HostsFiles,
		timeoutIdx:                &conf.Timeout,
		cacheMinTTLIdx:            &conf.CacheMinTTL,
		cacheMaxTTLIdx:            &conf.CacheMaxTTL,
		cacheSizeBytesIdx:         &conf.CacheSizeBytes,
		ratelimitIdx:              &conf.Ratelimit,
		ratelimitSubnetLenIPv4Idx: &conf.RatelimitSubnetLenIPv4,
		ratelimitSubnetLenIPv6Idx: &conf.RatelimitSubnetLenIPv6,
		udpBufferSizeIdx:          &conf.UDPBufferSize,
		maxGoRoutinesIdx:          &conf.MaxGoRoutines,
		tlsMinVersionIdx:          &conf.TLSMinVersion,
		tlsMaxVersionIdx:          &conf.TLSMaxVersion,
		helpIdx:                   &conf.help,
		hostsFileEnabledIdx:       &conf.HostsFileEnabled,
		pprofIdx:                  &conf.Pprof,
		versionIdx:                &conf.Version,
		verboseIdx:                &conf.Verbose,
		insecureIdx:               &conf.Insecure,
		ipv6DisabledIdx:           &conf.IPv6Disabled,
		http3Idx:                  &conf.HTTP3,
		cacheOptimisticIdx:        &conf.CacheOptimistic,
		cacheIdx:                  &conf.Cache,
		refuseAnyIdx:              &conf.RefuseAny,
		enableEDNSSubnetIdx:       &conf.EnableEDNSSubnet,
		dns64Idx:                  &conf.DNS64,
		usePrivateRDNSIdx:         &conf.UsePrivateRDNS,
	}

	flags := flag.NewFlagSet(cmdName, flag.ContinueOnError)
	for i, fieldPtr := range fieldPtrs {
		addOption(flags, fieldPtr, commandLineOptions[i])
	}

	flags.Usage = func() { usage(cmdName, os.Stderr) }

	if err = flags.Parse(args); err != nil {
		// Don't wrap the error, because it's informative enough as is.
		return err
	}

	if rest := flags.Args(); len(rest) > 0 {
		return fmt.Errorf("positional arguments are not allowed, please check your command line "+
			"arguments; detected positional arguments: %s", rest)
	}

	return nil
}

// defineFlag registers fieldPtr through setFlag under the long name of o and,
// if o has one, under its short name.  The current value of *fieldPtr is used
// as the default.  o must not be nil.
func defineFlag[T any](
	fieldPtr *T,
	o *commandLineOption,
	setFlag func(p *T, name string, value T, usage string),
) {
	define := func(name string) {
		setFlag(fieldPtr, name, *fieldPtr, o.description)
	}

	define(o.long)
	if o.short != "" {
		define(o.short)
	}
}

// defineFlagVar registers value in flags under the long name of o and, if o
// has one, under its short name.  Both names share the same [flag.Value].  o
// must not be nil.
func defineFlagVar(flags *flag.FlagSet, value flag.Value, o *commandLineOption) {
	names := []string{o.long}
	if o.short != "" {
		names = append(names, o.short)
	}

	for _, name := range names {
		flags.Var(value, name, o.description)
	}
}

// defineTimeutilDurationFlag registers a text-based flag for the
// [*timeutil.Duration] fieldPtr under the long name of o and, if o has one,
// under its short name.  The current value of *fieldPtr is used as the
// default.  o must not be nil.
func defineTimeutilDurationFlag(
	flags *flag.FlagSet,
	fieldPtr *timeutil.Duration,
	o *commandLineOption,
) {
	names := []string{o.long}
	if o.short != "" {
		names = append(names, o.short)
	}

	for _, name := range names {
		flags.TextVar(fieldPtr, name, *fieldPtr, o.description)
	}
}

// addOption registers the command-line option o in flags and binds it to the
// value that fieldPtr points to.  How the flag is defined depends on the
// dynamic type of fieldPtr.  An unsupported type causes a panic with an error
// wrapping [errors.ErrBadEnumValue].
func addOption(flags *flag.FlagSet, fieldPtr any, o *commandLineOption) {
	switch p := fieldPtr.(type) {
	case *bool:
		defineFlag(p, o, flags.BoolVar)
	case *string:
		defineFlag(p, o, flags.StringVar)
	case *int:
		defineFlag(p, o, flags.IntVar)
	case *uint:
		defineFlag(p, o, flags.UintVar)
	case *uint32:
		defineFlagVar(flags, (*uint32Value)(p), o)
	case *float32:
		defineFlagVar(flags, (*float32Value)(p), o)
	case *[]int:
		defineFlagVar(flags, newIntSliceValue(p), o)
	case *[]string:
		defineFlagVar(flags, newStringSliceValue(p), o)
	case *timeutil.Duration:
		defineTimeutilDurationFlag(flags, p, o)
	default:
		panic(fmt.Errorf("unexpected field pointer type %T: %w", p, errors.ErrBadEnumValue))
	}
}

// usage writes a help message to output.  It is similar to the one printed by
// package flag, but options are sorted by their long names, and each option
// shows both its long and short forms along with a value hint.
func usage(cmdName string, output io.Writer) {
	sorted := slices.Clone(commandLineOptions)
	byLong := func(a, b *commandLineOption) (res int) {
		return strings.Compare(a.long, b.long)
	}
	slices.SortStableFunc(sorted, byLong)

	var sb strings.Builder
	_, _ = fmt.Fprintf(&sb, "Usage of %s:\n", cmdName)
	for _, o := range sorted {
		writeUsageLine(&sb, o)

		// Use four spaces before the tab to trigger good alignment for both 4-
		// and 8-space tab stops.
		_, _ = fmt.Fprintf(&sb, "    \t%s\n", o.description)
	}

	_, _ = io.WriteString(output, sb.String())
}

// writeUsageLine appends to b the line that names the option o.  It covers the
// four combinations of having or lacking a short form and a value hint.  o
// must not be nil.
func writeUsageLine(b *strings.Builder, o *commandLineOption) {
	hasShort, hasValue := o.short != "", o.valueType != ""

	switch {
	case !hasShort && !hasValue:
		_, _ = fmt.Fprintf(b, "  --%s\n", o.long)
	case !hasShort:
		_, _ = fmt.Fprintf(b, "  --%s=%s\n", o.long, o.valueType)
	case !hasValue:
		_, _ = fmt.Fprintf(b, "  --%s/-%s\n", o.long, o.short)
	default:
		_, _ = fmt.Fprintf(b, "  --%[1]s=%[3]s/-%[2]s %[3]s\n", o.long, o.short, o.valueType)
	}
}

// processCmdLineOptions decides, from the result of command-line parsing,
// whether dnsproxy should exit right away and with which code.
//
// In order:
//   - a parse error exits with an argument error code;
//   - --help prints usage to stdout and exits successfully;
//   - --version prints the version and exits successfully.
//
// Otherwise, needExit is false.
func processCmdLineOptions(conf *configuration, parseErr error) (exitCode int, needExit bool) {
	switch {
	case parseErr != nil:
		// Assume that usage has already been printed.
		return osutil.ExitCodeArgumentError, true
	case conf.help:
		usage(os.Args[0], os.Stdout)

		return osutil.ExitCodeSuccess, true
	case conf.Version:
		fmt.Printf("dnsproxy version %s\n", version.Version())

		return osutil.ExitCodeSuccess, true
	default:
		return osutil.ExitCodeSuccess, false
	}
}
07070100000024000081A4000000000000000000000001679A649F0000109E000000000000000000000000000000000000002400000000dnsproxy-0.75.0/internal/cmd/cmd.go// Package cmd is the dnsproxy CLI entry point.
package cmd

import (
	"context"
	"fmt"
	"log/slog"
	"net/http"
	"net/http/pprof"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/version"
	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/osutil"
)

// Main is the entrypoint of the dnsproxy CLI.  It takes no arguments; the
// configuration is read from os.Args and, optionally, a YAML file by
// parseConfig.  Main then sets up logging, optionally starts the pprof server,
// and runs the proxy.
//
// When parseConfig returns a nil configuration, Main exits with the exit code
// that parseConfig returned.  This covers printing help or the version as well
// as parse errors.  If the log file cannot be opened or runProxy fails, it
// exits with a non-zero code.
func Main() {
	conf, exitCode, err := parseConfig()
	if err != nil {
		_, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("parsing options: %w", err))
	}

	// A nil conf means parseConfig has decided that the process should exit.
	if conf == nil {
		os.Exit(exitCode)
	}

	// Log to stdout unless a log file path is configured.
	logOutput := os.Stdout
	if conf.LogOutput != "" {
		// #nosec G302 -- Trust the file path that is given in the
		// configuration.
		logOutput, err = os.OpenFile(conf.LogOutput, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
		if err != nil {
			_, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("cannot create a log file: %s", err))

			os.Exit(osutil.ExitCodeArgumentError)
		}

		defer func() { _ = logOutput.Close() }()
	}

	lvl := slog.LevelInfo
	if conf.Verbose {
		lvl = slog.LevelDebug
	}

	l := slogutil.New(&slogutil.Config{
		Output: logOutput,
		Format: slogutil.FormatDefault,
		Level:  lvl,
		// TODO(d.kolyshev): Consider making configurable.
		AddTimestamp: true,
	})

	ctx := context.Background()

	if conf.Pprof {
		runPprof(l)
	}

	// runProxy blocks until the proxy is stopped.
	err = runProxy(ctx, l, conf)
	if err != nil {
		l.ErrorContext(ctx, "running dnsproxy", slogutil.KeyError, err)

		// As defers are skipped in case of os.Exit, close logOutput manually.
		//
		// TODO(a.garipov): Consider making logger.Close method.
		if logOutput != os.Stdout {
			_ = logOutput.Close()
		}

		os.Exit(osutil.ExitCodeFailure)
	}
}

// runProxy logs the build information and creates the proxy from conf.  It
// then starts the proxy and blocks until the process receives SIGINT or
// SIGTERM, after which it shuts the proxy down.  Every failure is returned
// wrapped with the stage at which it happened.  l must not be nil.
//
// TODO(e.burkov):  Move into separate dnssvc package.
func runProxy(ctx context.Context, l *slog.Logger, conf *configuration) (err error) {
	l.InfoContext(
		ctx,
		"dnsproxy starting",
		"version", version.Version(),
		"revision", version.Revision(),
		"branch", version.Branch(),
		"commit_time", version.CommitTime(),
	)

	proxyConf, err := createProxyConfig(ctx, l, conf)
	if err != nil {
		return fmt.Errorf("configuring proxy: %w", err)
	}

	p, err := proxy.New(proxyConf)
	if err != nil {
		return fmt.Errorf("creating proxy: %w", err)
	}

	if err = p.Start(ctx); err != nil {
		return fmt.Errorf("starting dnsproxy: %w", err)
	}

	// Wait for a termination signal.
	//
	// TODO(e.burkov):  Use [service.SignalHandler].
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh

	if err = p.Shutdown(ctx); err != nil {
		return fmt.Errorf("stopping dnsproxy: %w", err)
	}

	return nil
}

// runPprof serves the net/http/pprof handlers on localhost:6060.  The server
// runs in a new goroutine, and runPprof returns immediately.  Server errors
// other than [http.ErrServerClosed] are logged with l.  l must not be nil.
//
// TODO(e.burkov):  Use [httputil.RoutePprof].
func runPprof(l *slog.Logger) {
	mux := http.NewServeMux()
	mux.HandleFunc("/debug/pprof/", pprof.Index)
	mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
	mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
	mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
	mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
	mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
	mux.Handle("/debug/pprof/block", pprof.Handler("block"))
	mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
	mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
	mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
	mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))

	go func() {
		// TODO(d.kolyshev): Consider making configurable.
		pprofAddr := "localhost:6060"
		l.Info("starting pprof", "addr", pprofAddr)

		srv := &http.Server{
			Addr:        pprofAddr,
			ReadTimeout: 60 * time.Second,
			Handler:     mux,
		}

		err := srv.ListenAndServe()
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			// slog messages are not format strings, so the address is
			// passed as an attribute rather than through a verb.
			l.Error("pprof failed to listen", "addr", pprofAddr, slogutil.KeyError, err)
		}
	}()
}
07070100000025000081A4000000000000000000000001679A649F0000215A000000000000000000000000000000000000002700000000dnsproxy-0.75.0/internal/cmd/config.gopackage cmd

import (
	"fmt"
	"os"
	"time"

	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/golibs/osutil"
	"github.com/AdguardTeam/golibs/timeutil"
	"gopkg.in/yaml.v3"
)

// configuration represents dnsproxy configuration.  parseConfig fills it from
// its defaults, the command-line options, and, when ConfigPath is set, the
// YAML file at that path.  The yaml tags define the keys of that file.
type configuration struct {
	// ConfigPath is the path to the YAML configuration file.  It is set from
	// the --config-path command-line option.
	ConfigPath string

	// LogOutput is the path to the log file.  If it is empty, logs are written
	// to stdout.
	LogOutput string `yaml:"output"`

	// TLSCertPath is the path to the file with the certificate chain.
	TLSCertPath string `yaml:"tls-crt"`

	// TLSKeyPath is the path to the file with the private key.
	TLSKeyPath string `yaml:"tls-key"`

	// HTTPSServerName sets the Server header for the HTTPS server.  Default is
	// "dnsproxy".
	HTTPSServerName string `yaml:"https-server-name"`

	// HTTPSUserinfo is the sole permitted userinfo for the DoH basic
	// authentication.  If it is set, all DoH queries are required to have this
	// basic authentication information.
	HTTPSUserinfo string `yaml:"https-userinfo"`

	// DNSCryptConfigPath is the path to the DNSCrypt configuration file.
	DNSCryptConfigPath string `yaml:"dnscrypt-config"`

	// EDNSAddr is the custom EDNS Client Address to send.
	EDNSAddr string `yaml:"edns-addr"`

	// UpstreamMode determines the logic through which upstreams will be used.
	// If not specified the [proxy.UpstreamModeLoadBalance] is used.
	UpstreamMode string `yaml:"upstream-mode"`

	// ListenAddrs is the list of server's listen addresses.
	ListenAddrs []string `yaml:"listen-addrs"`

	// ListenPorts are the ports server listens on.
	ListenPorts []int `yaml:"listen-ports"`

	// HTTPSListenPorts are the ports server listens on for DNS-over-HTTPS.
	HTTPSListenPorts []int `yaml:"https-port"`

	// TLSListenPorts are the ports server listens on for DNS-over-TLS.
	TLSListenPorts []int `yaml:"tls-port"`

	// QUICListenPorts are the ports server listens on for DNS-over-QUIC.
	QUICListenPorts []int `yaml:"quic-port"`

	// DNSCryptListenPorts are the ports server listens on for DNSCrypt.
	DNSCryptListenPorts []int `yaml:"dnscrypt-port"`

	// Upstreams is the list of DNS upstream servers.
	Upstreams []string `yaml:"upstream"`

	// BootstrapDNS is the list of bootstrap DNS upstream servers.
	BootstrapDNS []string `yaml:"bootstrap"`

	// Fallbacks is the list of fallback DNS upstream servers.
	Fallbacks []string `yaml:"fallback"`

	// PrivateRDNSUpstreams are upstreams to use for reverse DNS lookups of
	// private addresses, including the requests for authority records, such as
	// SOA and NS.
	PrivateRDNSUpstreams []string `yaml:"private-rdns-upstream"`

	// DNS64Prefix defines the DNS64 prefixes that dnsproxy should use when it
	// acts as a DNS64 server.  If not specified, dnsproxy uses the default
	// Well-Known Prefix.  This option can be specified multiple times.
	DNS64Prefix []string `yaml:"dns64-prefix"`

	// PrivateSubnets is the list of private subnets to determine private
	// addresses.
	PrivateSubnets []string `yaml:"private-subnets"`

	// BogusNXDomain transforms responses that contain at least one of the given
	// IP addresses into NXDOMAIN.
	//
	// TODO(a.garipov): Find a way to use [netutil.Prefix].  Currently, the
	// command-line flag registration in addOption only supports a fixed set of
	// field types, which doesn't include it.
	BogusNXDomain []string `yaml:"bogus-nxdomain"`

	// HostsFiles is the list of paths to the hosts files to resolve from.
	HostsFiles []string `yaml:"hosts-files"`

	// Timeout for outbound DNS queries to remote upstream servers in a
	// human-readable form.  Default is 10s.
	Timeout timeutil.Duration `yaml:"timeout"`

	// CacheMinTTL is the minimum TTL value for caching DNS entries, in seconds.
	// It overrides the TTL value from the upstream server, if the one is less.
	CacheMinTTL uint32 `yaml:"cache-min-ttl"`

	// CacheMaxTTL is the maximum TTL value for caching DNS entries, in seconds.
	// It overrides the TTL value from the upstream server, if the one is
	// greater.
	CacheMaxTTL uint32 `yaml:"cache-max-ttl"`

	// CacheSizeBytes is the cache size in bytes.  Default is 64k.
	CacheSizeBytes int `yaml:"cache-size"`

	// Ratelimit is the maximum number of requests per second.
	Ratelimit int `yaml:"ratelimit"`

	// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
	// rate limiting requests.  Default is 24.
	RatelimitSubnetLenIPv4 int `yaml:"ratelimit-subnet-len-ipv4"`

	// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
	// rate limiting requests.  Default is 56.
	RatelimitSubnetLenIPv6 int `yaml:"ratelimit-subnet-len-ipv6"`

	// UDPBufferSize is the size of the UDP buffer in bytes.  A value <= 0 will
	// use the system default.
	UDPBufferSize int `yaml:"udp-buf-size"`

	// MaxGoRoutines is the maximum number of goroutines.  A zero value means
	// no maximum is set.
	MaxGoRoutines uint `yaml:"max-go-routines"`

	// TLSMinVersion is the minimum allowed version of TLS.
	//
	// TODO(d.kolyshev): Use more suitable type.
	TLSMinVersion float32 `yaml:"tls-min-version"`

	// TLSMaxVersion is the maximum allowed version of TLS.
	//
	// TODO(d.kolyshev): Use more suitable type.
	TLSMaxVersion float32 `yaml:"tls-max-version"`

	// help, if true, prints the command-line option help message and quits
	// with a successful exit code.  It is unexported, so it can only be set
	// from the command line.
	help bool

	// HostsFileEnabled controls whether hosts files are used for resolving or
	// not.  Default is true.
	HostsFileEnabled bool `yaml:"hosts-file-enabled"`

	// Pprof defines whether the pprof information needs to be exposed via
	// localhost:6060 or not.
	Pprof bool `yaml:"pprof"`

	// Version, if true, prints the program version, and exits.
	//
	// NOTE(review): this is only checked before the YAML file is parsed, so
	// setting it in the file appears to have no effect.
	Version bool `yaml:"version"`

	// Verbose controls the verbosity of the output.
	Verbose bool `yaml:"verbose"`

	// Insecure disables upstream servers TLS certificate verification.
	Insecure bool `yaml:"insecure"`

	// IPv6Disabled makes the server respond with NODATA to all AAAA queries.
	IPv6Disabled bool `yaml:"ipv6-disabled"`

	// HTTP3 controls whether HTTP/3 is enabled for this instance of dnsproxy.
	// It enables HTTP/3 support for both the DoH upstreams and the DoH server.
	HTTP3 bool `yaml:"http3"`

	// CacheOptimistic, if set to true, enables the optimistic DNS cache.  That
	// means that cached results will be served even if their cache TTL has
	// already expired.
	CacheOptimistic bool `yaml:"cache-optimistic"`

	// Cache controls whether DNS responses are cached or not.
	Cache bool `yaml:"cache"`

	// RefuseAny makes the server refuse requests of type ANY.
	RefuseAny bool `yaml:"refuse-any"`

	// EnableEDNSSubnet uses EDNS Client Subnet extension.
	EnableEDNSSubnet bool `yaml:"edns"`

	// DNS64 defines whether DNS64 functionality is enabled or not.
	DNS64 bool `yaml:"dns64"`

	// UsePrivateRDNS makes the server use private upstreams for reverse DNS
	// lookups of private addresses, including the requests for authority
	// records, such as SOA and NS.
	UsePrivateRDNS bool `yaml:"use-private-rdns"`
}

// parseConfig builds the configuration in three steps:
//  1. apply the defaults;
//  2. apply the command-line options;
//  3. if --config-path is given, apply the YAML file it points to.
//
// Command-line options have priority over the file, so they are parsed once
// more after it.  A nil conf means the caller should exit with exitCode; err,
// if not nil, describes the failure.
func parseConfig() (conf *configuration, exitCode int, err error) {
	conf = &configuration{
		HTTPSServerName:        "dnsproxy",
		UpstreamMode:           string(proxy.UpstreamModeLoadBalance),
		CacheSizeBytes:         64 * 1024,
		Timeout:                timeutil.Duration(10 * time.Second),
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 56,
		HostsFileEnabled:       true,
	}

	err = parseCmdLineOptions(conf)

	var needExit bool
	if exitCode, needExit = processCmdLineOptions(conf, err); needExit {
		return nil, exitCode, err
	}

	confPath := conf.ConfigPath
	if confPath == "" {
		return conf, exitCode, nil
	}

	// TODO(d.kolyshev): Bootstrap and use slog.
	fmt.Printf("dnsproxy config path: %s\n", confPath)

	if err = parseConfigFile(conf, confPath); err != nil {
		return nil, osutil.ExitCodeFailure, fmt.Errorf("parsing config file %s: %w", confPath, err)
	}

	// Parse command-line args again as it has priority over YAML config.
	if err = parseCmdLineOptions(conf); err != nil {
		// Don't wrap the error, because it's informative enough as is.
		return nil, osutil.ExitCodeFailure, err
	}

	return conf, exitCode, nil
}

// parseConfigFile reads the YAML file at confPath and decodes it into conf.
// Only the fields present in the file are overwritten.
func parseConfigFile(conf *configuration, confPath string) (err error) {
	// #nosec G304 -- Trust the file path that is given in the args.
	data, readErr := os.ReadFile(confPath)
	if readErr != nil {
		return fmt.Errorf("reading file: %w", readErr)
	}

	if err = yaml.Unmarshal(data, conf); err != nil {
		return fmt.Errorf("unmarshalling file: %w", err)
	}

	return nil
}
07070100000026000081A4000000000000000000000001679A649F00000E99000000000000000000000000000000000000002500000000dnsproxy-0.75.0/internal/cmd/flag.gopackage cmd

import (
	"flag"
	"fmt"
	"strconv"
	"strings"

	"github.com/AdguardTeam/golibs/stringutil"
)

// uint32Value is a uint32 that can be defined as a flag for [flag.FlagSet].
type uint32Value uint32

// type check
var _ flag.Value = (*uint32Value)(nil)

// Set implements the [flag.Value] interface for *uint32Value.  s may use a base
// prefix, e.g. "0x10" is parsed as 16.  The value is only updated when s is a
// valid uint32.  On error the previous value is kept.
func (i *uint32Value) Set(s string) (err error) {
	v, err := strconv.ParseUint(s, 0, 32)
	if err != nil {
		// Don't store the zero or clamped value strconv returns on failure.
		return err
	}

	*i = uint32Value(v)

	return nil
}

// String implements the [flag.Value] interface for *uint32Value.  It returns
// the decimal representation of the value.
func (i *uint32Value) String() (out string) {
	return strconv.FormatUint(uint64(*i), 10)
}

// float32Value is a float32 that can be defined as a flag for [flag.FlagSet].
type float32Value float32

// type check
var _ flag.Value = (*float32Value)(nil)

// Set implements the [flag.Value] interface for *float32Value.  The value is
// only updated when s is a valid float32.  On error the previous value is
// kept.
func (i *float32Value) Set(s string) (err error) {
	v, err := strconv.ParseFloat(s, 32)
	if err != nil {
		// Don't store the zero or infinite value strconv returns on failure.
		return err
	}

	*i = float32Value(v)

	return nil
}

// String implements the [flag.Value] interface for *float32Value.  It formats
// the value with three decimal places.
func (i *float32Value) String() (out string) {
	return strconv.FormatFloat(float64(*i), 'f', 3, 32)
}

// intSliceValue is a [flag.Value] that collects every occurrence of a flag into
// a slice of integers.
type intSliceValue struct {
	// values points to the slice that receives the parsed integers.
	values *[]int

	// isSet reports whether the flag has already been seen.  The first
	// occurrence replaces whatever default the slice held.
	isSet bool
}

// newIntSliceValue returns an *intSliceValue that stores the parsed values into
// p.
func newIntSliceValue(p *[]int) (out *intSliceValue) {
	out = &intSliceValue{values: p}

	return out
}

// type check
var _ flag.Value = (*intSliceValue)(nil)

// Set implements the [flag.Value] interface for *intSliceValue.  It appends the
// integer parsed from s.
func (i *intSliceValue) Set(s string) (err error) {
	n, convErr := strconv.Atoi(s)
	if convErr != nil {
		return fmt.Errorf("parsing integer slice arg %q: %w", s, convErr)
	}

	// Drop the default on the first occurrence of the flag.
	if !i.isSet {
		*i.values = []int{}
		i.isSet = true
	}

	*i.values = append(*i.values, n)

	return nil
}

// String implements the [flag.Value] interface for *intSliceValue.  It returns
// the values joined with commas, or an empty string if there is no slice.
func (i *intSliceValue) String() (out string) {
	if i == nil || i.values == nil {
		return ""
	}

	parts := make([]string, 0, len(*i.values))
	for _, n := range *i.values {
		parts = append(parts, strconv.Itoa(n))
	}

	return strings.Join(parts, ",")
}

// stringSliceValue is a [flag.Value] that collects every occurrence of a flag
// into a slice of strings.
type stringSliceValue struct {
	// values points to the slice that receives the flag arguments.
	values *[]string

	// isSet reports whether the flag has already been seen.  The first
	// occurrence replaces whatever default the slice held.
	isSet bool
}

// newStringSliceValue returns a *stringSliceValue that stores the flag
// arguments into p.
func newStringSliceValue(p *[]string) (out *stringSliceValue) {
	out = &stringSliceValue{values: p}

	return out
}

// type check
var _ flag.Value = (*stringSliceValue)(nil)

// Set implements the [flag.Value] interface for *stringSliceValue.  It appends
// s and never fails.
func (i *stringSliceValue) Set(s string) (err error) {
	// Drop the default on the first occurrence of the flag.
	if !i.isSet {
		*i.values = []string{}
		i.isSet = true
	}

	*i.values = append(*i.values, s)

	return nil
}

// String implements the [flag.Value] interface for *stringSliceValue.  It
// returns the values joined with commas, or an empty string if there is no
// slice.
func (i *stringSliceValue) String() (out string) {
	if i == nil || i.values == nil {
		return ""
	}

	b := &strings.Builder{}
	sep := ""
	for _, v := range *i.values {
		stringutil.WriteToBuilder(b, sep)
		stringutil.WriteToBuilder(b, v)
		sep = ","
	}

	return b.String()
}
07070100000027000081A4000000000000000000000001679A649F00003AA8000000000000000000000000000000000000002600000000dnsproxy-0.75.0/internal/cmd/proxy.gopackage cmd

import (
	"context"
	"crypto/tls"
	"fmt"
	"log/slog"
	"net"
	"net/netip"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
	"github.com/AdguardTeam/dnsproxy/internal/handler"
	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/osutil"
	"github.com/ameshkov/dnscrypt/v2"
	"gopkg.in/yaml.v3"
)

// TODO(e.burkov):  Use a separate type for the YAML configuration file.

// createProxyConfig builds a [proxy.Config] from conf.  It first reads the hosts
// files for the default request handler, then initializes the remaining parts
// of the configuration and joins their errors.  The returned config may be
// partially filled when err is not nil.  l must not be nil.
func createProxyConfig(
	ctx context.Context,
	l *slog.Logger,
	conf *configuration,
) (proxyConf *proxy.Config, err error) {
	paths, err := conf.hostsFiles(ctx, l)
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return nil, err
	}

	hosts, err := handler.ReadHosts(paths)
	if err != nil {
		return nil, fmt.Errorf("reading hosts files: %w", err)
	}

	defaultHandler := handler.NewDefault(&handler.DefaultConfig{
		// TODO(e.burkov):  Use the configured message constructor.
		MessageConstructor: dnsmsg.DefaultMessageConstructor{},
		Logger:             l.With(slogutil.KeyPrefix, "default_handler"),
		HaltIPv6:           conf.IPv6Disabled,
		HostsFiles:         hosts,
	})

	proxyConf = &proxy.Config{
		Logger:         l.With(slogutil.KeyPrefix, proxy.LogPrefix),
		RequestHandler: defaultHandler.HandleRequest,

		RatelimitSubnetLenIPv4: conf.RatelimitSubnetLenIPv4,
		RatelimitSubnetLenIPv6: conf.RatelimitSubnetLenIPv6,
		Ratelimit:              conf.Ratelimit,

		CacheEnabled:    conf.Cache,
		CacheSizeBytes:  conf.CacheSizeBytes,
		CacheMinTTL:     conf.CacheMinTTL,
		CacheMaxTTL:     conf.CacheMaxTTL,
		CacheOptimistic: conf.CacheOptimistic,

		RefuseAny:              conf.RefuseAny,
		HTTP3:                  conf.HTTP3,
		EnableEDNSClientSubnet: conf.EnableEDNSSubnet,
		UDPBufferSize:          conf.UDPBufferSize,
		HTTPSServerName:        conf.HTTPSServerName,
		MaxGoroutines:          conf.MaxGoRoutines,

		UsePrivateRDNS: conf.UsePrivateRDNS,
		PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),

		// TODO(e.burkov):  The following CIDRs are aimed to match any address.
		// This is not quite proper approach to be used by default so think
		// about configuring it.
		TrustedProxies: netutil.SliceSubnetSet{
			netip.MustParsePrefix("0.0.0.0/0"),
			netip.MustParsePrefix("::0/0"),
		},
	}

	if ui := conf.HTTPSUserinfo; ui != "" {
		if user, pass, hasPass := strings.Cut(ui, ":"); hasPass {
			proxyConf.Userinfo = url.UserPassword(user, pass)
		} else {
			proxyConf.Userinfo = url.User(user)
		}
	}

	conf.initBogusNXDomain(ctx, l, proxyConf)

	// The arguments are evaluated left to right, and that order matters:
	// initListenAddrs reads the TLS and DNSCrypt settings set up before it.
	err = errors.Join(
		conf.initUpstreams(ctx, l, proxyConf),
		conf.initEDNS(ctx, l, proxyConf),
		conf.initTLSConfig(proxyConf),
		conf.initDNSCryptConfig(proxyConf),
		conf.initListenAddrs(proxyConf),
		conf.initSubnets(proxyConf),
	)

	return proxyConf, err
}

// isEmpty reports whether uc has no upstreams at all: no general upstreams, no
// domain-reserved upstreams, and no domain-specified upstreams.  uc must not be
// nil.
//
// TODO(e.burkov):  Think of a better way to validate the config.  Perhaps,
// return an error from [ParseUpstreamsConfig] if no upstreams were initialized.
func isEmpty(uc *proxy.UpstreamConfig) (ok bool) {
	if len(uc.Upstreams) > 0 || len(uc.DomainReservedUpstreams) > 0 {
		return false
	}

	return len(uc.SpecifiedDomainUpstreams) == 0
}

// defaultLocalTimeout is the default timeout for local operations.  It also
// caps the timeout of the private reverse DNS upstreams.
const defaultLocalTimeout time.Duration = time.Second

// initUpstreams fills the upstream-related fields of config.  These are the
// general upstreams, the private reverse DNS upstreams, the fallback upstreams,
// and the upstream mode.  The private and fallback configurations are only set
// when they contain at least one upstream.
//
// TODO(d.kolyshev): Join errors.
func (conf *configuration) initUpstreams(
	ctx context.Context,
	l *slog.Logger,
	config *proxy.Config,
) (err error) {
	versions := upstream.DefaultHTTPVersions
	if conf.HTTP3 {
		versions = []upstream.HTTPVersion{
			upstream.HTTPVersion3,
			upstream.HTTPVersion2,
			upstream.HTTPVersion11,
		}
	}

	timeout := time.Duration(conf.Timeout)

	boot, err := initBootstrap(ctx, l, conf.BootstrapDNS, &upstream.Options{
		Logger:             l,
		HTTPVersions:       versions,
		InsecureSkipVerify: conf.Insecure,
		Timeout:            timeout,
	})
	if err != nil {
		return fmt.Errorf("initializing bootstrap: %w", err)
	}

	opts := &upstream.Options{
		Logger:             l,
		HTTPVersions:       versions,
		InsecureSkipVerify: conf.Insecure,
		Bootstrap:          boot,
		Timeout:            timeout,
	}

	config.UpstreamConfig, err = proxy.ParseUpstreamsConfig(loadServersList(conf.Upstreams), opts)
	if err != nil {
		return fmt.Errorf("parsing upstreams configuration: %w", err)
	}

	// Private upstreams get a timeout of at most defaultLocalTimeout and don't
	// inherit the Insecure setting.
	privateOpts := &upstream.Options{
		Logger:       l,
		HTTPVersions: versions,
		Bootstrap:    boot,
		Timeout:      min(defaultLocalTimeout, timeout),
	}

	private, err := proxy.ParseUpstreamsConfig(loadServersList(conf.PrivateRDNSUpstreams), privateOpts)
	if err != nil {
		return fmt.Errorf("parsing private rdns upstreams configuration: %w", err)
	}

	if !isEmpty(private) {
		config.PrivateRDNSUpstreamConfig = private
	}

	fallbacks, err := proxy.ParseUpstreamsConfig(loadServersList(conf.Fallbacks), opts)
	if err != nil {
		return fmt.Errorf("parsing fallback upstreams configuration: %w", err)
	}

	if !isEmpty(fallbacks) {
		config.Fallbacks = fallbacks
	}

	if conf.UpstreamMode == "" {
		config.UpstreamMode = proxy.UpstreamModeLoadBalance

		return nil
	}

	err = config.UpstreamMode.UnmarshalText([]byte(conf.UpstreamMode))
	if err != nil {
		return fmt.Errorf("parsing upstream mode: %w", err)
	}

	return nil
}

// initBootstrap returns the [upstream.Resolver] used to bootstrap upstream
// servers.  Each bootstrap address becomes a caching resolver.  Several of
// them are queried in parallel, and a single one is returned as is.  With no
// bootstraps, the system hosts files are consulted before the default
// resolver.  If the hosts resolver can't be created, the error is logged and
// only the default resolver is used.
func initBootstrap(
	ctx context.Context,
	l *slog.Logger,
	bootstraps []string,
	opts *upstream.Options,
) (r upstream.Resolver, err error) {
	resolvers := make([]upstream.Resolver, 0, len(bootstraps))
	for idx, addr := range bootstraps {
		ur, urErr := upstream.NewUpstreamResolver(addr, opts)
		if urErr != nil {
			return nil, fmt.Errorf("creating bootstrap resolver at index %d: %w", idx, urErr)
		}

		resolvers = append(resolvers, upstream.NewCachingResolver(ur))
	}

	switch {
	case len(resolvers) > 1:
		return upstream.ParallelResolver(resolvers), nil
	case len(resolvers) == 1:
		return resolvers[0], nil
	}

	etcHosts, hostsErr := upstream.NewDefaultHostsResolver(osutil.RootDirFS(), l)
	if hostsErr != nil {
		l.ErrorContext(ctx, "creating default hosts resolver", slogutil.KeyError, hostsErr)

		return net.DefaultResolver, nil
	}

	return upstream.ConsequentResolver{etcHosts, net.DefaultResolver}, nil
}

// initEDNS sets config.EDNSAddr from conf.EDNSAddr.  An empty address is a
// no-op.  A non-empty address without EDNS Client Subnet enabled is ignored
// with a warning.
func (conf *configuration) initEDNS(
	ctx context.Context,
	l *slog.Logger,
	config *proxy.Config,
) (err error) {
	switch {
	case conf.EDNSAddr == "":
		return nil
	case !conf.EnableEDNSSubnet:
		l.WarnContext(ctx, "--edns is required", "--edns-addr", conf.EDNSAddr)

		return nil
	}

	if config.EDNSAddr, err = netutil.ParseIP(conf.EDNSAddr); err != nil {
		return fmt.Errorf("parsing edns-addr: %w", err)
	}

	return nil
}

// initBogusNXDomain appends the subnets parsed from conf.BogusNXDomain to
// config.BogusNXDomain.  Entries that fail to parse are logged and skipped.
func (conf *configuration) initBogusNXDomain(ctx context.Context, l *slog.Logger, config *proxy.Config) {
	for idx, subnet := range conf.BogusNXDomain {
		p, err := proxynetutil.ParseSubnet(subnet)
		if err != nil {
			// TODO(a.garipov): Consider returning this err as a proper error.
			l.WarnContext(ctx, "parsing bogus nxdomain", "index", idx, slogutil.KeyError, err)

			continue
		}

		config.BogusNXDomain = append(config.BogusNXDomain, p)
	}
}

// initTLSConfig loads the server TLS configuration into config.  It does
// nothing unless both the certificate and the key paths are set.
func (conf *configuration) initTLSConfig(config *proxy.Config) (err error) {
	if conf.TLSCertPath == "" || conf.TLSKeyPath == "" {
		return nil
	}

	c, err := newTLSConfig(conf)
	if err != nil {
		return fmt.Errorf("loading TLS config: %w", err)
	}

	config.TLSConfig = c

	return nil
}

// initDNSCryptConfig reads the DNSCrypt resolver configuration from
// conf.DNSCryptConfigPath and creates the resolver certificate.  It stores the
// certificate and the provider name in config.  It does nothing when the path
// is empty.
func (conf *configuration) initDNSCryptConfig(config *proxy.Config) (err error) {
	path := conf.DNSCryptConfigPath
	if path == "" {
		return nil
	}

	// #nosec G304 -- Trust the file path that is given in the configuration.
	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("reading DNSCrypt config %q: %w", path, err)
	}

	rc := &dnscrypt.ResolverConfig{}
	if err = yaml.Unmarshal(data, rc); err != nil {
		return fmt.Errorf("unmarshalling DNSCrypt config: %w", err)
	}

	cert, err := rc.CreateCert()
	if err != nil {
		return fmt.Errorf("creating DNSCrypt certificate: %w", err)
	}

	config.DNSCryptResolverCert, config.DNSCryptProviderName = cert, rc.ProviderName

	return nil
}

// parseListenAddrs parses addrStrs into listen IP addresses.  If addrStrs is
// empty, it returns a slice with only the IPv4 unspecified address "0.0.0.0".
// On error, addrs is nil and err wraps the underlying parsing error.
//
// TODO(d.kolyshev): Join errors.
func parseListenAddrs(addrStrs []string) (addrs []netip.Addr, err error) {
	if len(addrStrs) == 0 {
		// If ListenAddrs has not been parsed through config file nor command
		// line we set it to "0.0.0.0".
		//
		// TODO(a.garipov): Consider using localhost.
		return []netip.Addr{netip.IPv4Unspecified()}, nil
	}

	addrs = make([]netip.Addr, 0, len(addrStrs))
	for i, a := range addrStrs {
		ip, parseErr := netip.ParseAddr(a)
		if parseErr != nil {
			// Keep the underlying error, which also names the bad input, so
			// that callers can inspect it.
			return nil, fmt.Errorf("parsing listen address at index %d: %w", i, parseErr)
		}

		addrs = append(addrs, ip)
	}

	return addrs, nil
}

// initListenAddrs fills the plain DNS, TLS, HTTPS, QUIC, and DNSCrypt listen
// addresses of config from the listen IPs and ports in conf.  If no plain DNS
// ports are configured, conf.ListenPorts is set to the default port 53.  It
// returns an error if any configured port is outside [0, 65535].  Such ports
// would otherwise be silently truncated by the conversion to uint16.
func (conf *configuration) initListenAddrs(config *proxy.Config) (err error) {
	addrs, err := parseListenAddrs(conf.ListenAddrs)
	if err != nil {
		return fmt.Errorf("parsing listen addresses: %w", err)
	}

	if len(conf.ListenPorts) == 0 {
		// If ListenPorts has not been parsed through config file nor command
		// line we set it to 53.
		conf.ListenPorts = []int{53}
	}

	err = errors.Join(
		checkListenPorts("listen", conf.ListenPorts),
		checkListenPorts("tls", conf.TLSListenPorts),
		checkListenPorts("https", conf.HTTPSListenPorts),
		checkListenPorts("quic", conf.QUICListenPorts),
		checkListenPorts("dnscrypt", conf.DNSCryptListenPorts),
	)
	if err != nil {
		return fmt.Errorf("validating listen ports: %w", err)
	}

	for _, port := range conf.ListenPorts {
		for _, ip := range addrs {
			addrPort := netip.AddrPortFrom(ip, uint16(port))

			config.UDPListenAddr = append(config.UDPListenAddr, net.UDPAddrFromAddrPort(addrPort))
			config.TCPListenAddr = append(config.TCPListenAddr, net.TCPAddrFromAddrPort(addrPort))
		}
	}

	initTLSListenAddrs(config, conf, addrs)
	initDNSCryptListenAddrs(config, conf, addrs)

	return nil
}

// checkListenPorts returns an error for the first port in ports that doesn't
// fit into uint16.  kind names the kind of ports in the error message.
func checkListenPorts(kind string, ports []int) (err error) {
	const maxPort = 1<<16 - 1

	for i, p := range ports {
		if p < 0 || p > maxPort {
			return fmt.Errorf("%s port at index %d: %d is out of range [0, %d]", kind, i, p, maxPort)
		}
	}

	return nil
}

// initTLSListenAddrs appends the TLS, HTTPS (TCP), and QUIC (UDP) listen
// addresses to proxyConf, built from every IP in addrs and the corresponding
// ports in conf.  It does nothing if proxyConf has no TLS configuration.
func initTLSListenAddrs(proxyConf *proxy.Config, conf *configuration, addrs []netip.Addr) {
	if proxyConf.TLSConfig == nil {
		return
	}

	for _, addr := range addrs {
		for _, p := range conf.TLSListenPorts {
			ap := netip.AddrPortFrom(addr, uint16(p))
			proxyConf.TLSListenAddr = append(proxyConf.TLSListenAddr, net.TCPAddrFromAddrPort(ap))
		}

		for _, p := range conf.HTTPSListenPorts {
			ap := netip.AddrPortFrom(addr, uint16(p))
			proxyConf.HTTPSListenAddr = append(proxyConf.HTTPSListenAddr, net.TCPAddrFromAddrPort(ap))
		}

		for _, p := range conf.QUICListenPorts {
			ap := netip.AddrPortFrom(addr, uint16(p))
			proxyConf.QUICListenAddr = append(proxyConf.QUICListenAddr, net.UDPAddrFromAddrPort(ap))
		}
	}
}

// initDNSCryptListenAddrs appends a TCP and a UDP DNSCrypt listen address to
// proxyConf for each combination of conf.DNSCryptListenPorts and addrs.  It
// does nothing unless both the DNSCrypt certificate and provider name are
// set.
func initDNSCryptListenAddrs(proxyConf *proxy.Config, conf *configuration, addrs []netip.Addr) {
	configured := proxyConf.DNSCryptResolverCert != nil && proxyConf.DNSCryptProviderName != ""
	if !configured {
		return
	}

	for _, port := range conf.DNSCryptListenPorts {
		for _, addr := range addrs {
			ap := netip.AddrPortFrom(addr, uint16(port))

			proxyConf.DNSCryptTCPListenAddr = append(
				proxyConf.DNSCryptTCPListenAddr,
				net.TCPAddrFromAddrPort(ap),
			)
			proxyConf.DNSCryptUDPListenAddr = append(
				proxyConf.DNSCryptUDPListenAddr,
				net.UDPAddrFromAddrPort(ap),
			)
		}
	}
}

// initSubnets copies the DNS64 flag into proxyConf and parses the DNS64
// prefixes when DNS64 is enabled.  If private reverse DNS is enabled, it then
// initializes the private subnets.
//
// TODO(d.kolyshev): Join errors.
func (conf *configuration) initSubnets(proxyConf *proxy.Config) (err error) {
	proxyConf.UseDNS64 = conf.DNS64
	if conf.DNS64 {
		for idx, prefStr := range conf.DNS64Prefix {
			pref, parseErr := netip.ParsePrefix(prefStr)
			if parseErr != nil {
				return fmt.Errorf("parsing dns64 prefix at index %d: %w", idx, parseErr)
			}

			proxyConf.DNS64Prefs = append(proxyConf.DNS64Prefs, pref)
		}
	}

	if conf.UsePrivateRDNS {
		return conf.initPrivateSubnets(proxyConf)
	}

	return nil
}

// initPrivateSubnets parses conf.PrivateSubnets.  If at least one prefix is
// configured, proxyConf.PrivateSubnets is replaced with the parsed prefixes.
// Otherwise it is left unchanged.  It returns an error for the first prefix
// that fails to parse.
func (conf *configuration) initPrivateSubnets(proxyConf *proxy.Config) (err error) {
	private := make([]netip.Prefix, 0, len(conf.PrivateSubnets))
	for i, p := range conf.PrivateSubnets {
		var pref netip.Prefix
		pref, err = netip.ParsePrefix(p)
		if err != nil {
			return fmt.Errorf("parsing private subnet at index %d: %w", i, err)
		}

		private = append(private, pref)
	}

	// Keep the previously set subnets when none are configured.
	if len(private) > 0 {
		proxyConf.PrivateSubnets = netutil.SliceSubnetSet(private)
	}

	return nil
}

// loadServersList expands sources into a list of DNS server addresses.  Each
// source is either a path to a file with one address per line, or an address
// itself.  A source that can't be read as a file is taken verbatim as an
// address.  In files, blank lines and lines starting with "!" or "#" are
// skipped, and surrounding whitespace is trimmed.
func loadServersList(sources []string) []string {
	var servers []string

	for _, source := range sources {
		// #nosec G304 -- Trust the file path that is given in the
		// configuration.
		data, err := os.ReadFile(source)
		if err != nil {
			// Ignore errors, just consider it a server address and not a file.
			// Don't look at data, since os.ReadFile may return partially read
			// content along with an error.
			servers = append(servers, source)

			continue
		}

		for _, line := range strings.Split(string(data), "\n") {
			line = strings.TrimSpace(line)

			// Ignore empty lines and comments in the file.
			if line == "" ||
				strings.HasPrefix(line, "!") ||
				strings.HasPrefix(line, "#") {
				continue
			}

			servers = append(servers, line)
		}
	}

	return servers
}

// hostsFiles returns the paths of the hosts files to resolve from.  It returns
// nil when hosts files are disabled.  Otherwise it returns the configured
// paths, or the default ones when none are configured.
func (conf *configuration) hostsFiles(ctx context.Context, l *slog.Logger) (paths []string, err error) {
	if !conf.HostsFileEnabled {
		l.DebugContext(ctx, "hosts files are disabled")

		return nil, nil
	}

	l.DebugContext(ctx, "hosts files are enabled")

	if configured := conf.HostsFiles; len(configured) != 0 {
		return configured, nil
	}

	defaultPaths, err := proxynetutil.DefaultHostsPaths()
	if err != nil {
		return nil, fmt.Errorf("getting default hosts files: %w", err)
	}

	l.DebugContext(ctx, "hosts files are not specified, using default", "paths", defaultPaths)

	return defaultPaths, nil
}
07070100000028000081A4000000000000000000000001679A649F0000079F000000000000000000000000000000000000002400000000dnsproxy-0.75.0/internal/cmd/tls.gopackage cmd

import (
	"crypto/tls"
	"fmt"
	"os"
)

// newTLSConfig returns a server TLS configuration built from conf.  The key
// pair is read from conf.TLSCertPath and conf.TLSKeyPath.
//
// The minimum version comes from conf.TLSMinVersion (1.1, 1.2, or 1.3; TLS 1.0
// otherwise).  The maximum version comes from conf.TLSMaxVersion (1.0, 1.1, or
// 1.2; TLS 1.3 otherwise).  It returns an error if the resulting minimum
// version is greater than the maximum one, since no version could be
// negotiated with such settings.
func newTLSConfig(conf *configuration) (c *tls.Config, err error) {
	// Set default TLS min/max versions.
	var minVersion uint16 = tls.VersionTLS10
	var maxVersion uint16 = tls.VersionTLS13

	switch conf.TLSMinVersion {
	case 1.1:
		minVersion = tls.VersionTLS11
	case 1.2:
		minVersion = tls.VersionTLS12
	case 1.3:
		minVersion = tls.VersionTLS13
	}

	switch conf.TLSMaxVersion {
	case 1.0:
		maxVersion = tls.VersionTLS10
	case 1.1:
		maxVersion = tls.VersionTLS11
	case 1.2:
		maxVersion = tls.VersionTLS12
	}

	if minVersion > maxVersion {
		return nil, fmt.Errorf(
			"tls min version %s is greater than tls max version %s",
			tls.VersionName(minVersion),
			tls.VersionName(maxVersion),
		)
	}

	cert, err := loadX509KeyPair(conf.TLSCertPath, conf.TLSKeyPath)
	if err != nil {
		return nil, fmt.Errorf("loading TLS cert: %w", err)
	}

	// #nosec G402 -- TLS MinVersion is configured by user.
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		MinVersion:   minVersion,
		MaxVersion:   maxVersion,
	}, nil
}

// loadX509KeyPair reads the PEM-encoded certificate chain from certFile and
// the private key from keyFile, and parses them into a [tls.Certificate].  The
// certificate file may contain intermediate certificates after the leaf one.
// Errors from reading either file are returned unwrapped.
func loadX509KeyPair(certFile, keyFile string) (crt tls.Certificate, err error) {
	// #nosec G304 -- Trust the file path that is given in the configuration.
	certPEM, err := os.ReadFile(certFile)
	if err != nil {
		return crt, err
	}

	// #nosec G304 -- Trust the file path that is given in the configuration.
	keyPEM, err := os.ReadFile(keyFile)
	if err != nil {
		return crt, err
	}

	crt, err = tls.X509KeyPair(certPEM, keyPEM)

	return crt, err
}
07070100000029000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002000000000dnsproxy-0.75.0/internal/dnsmsg0707010000002A000081A4000000000000000000000001679A649F00000D61000000000000000000000000000000000000002F00000000dnsproxy-0.75.0/internal/dnsmsg/constructor.go// Package dnsmsg contains common constants, functions, and types for inspecting
// and constructing DNS messages.
package dnsmsg

import (
	"strings"

	"github.com/miekg/dns"
)

// MessageConstructor creates DNS response messages.  Each method returns a new
// reply to req with the response code its name refers to.
type MessageConstructor interface {
	// NewMsgNODATA returns an empty reply to req with the NOERROR code, also
	// known as a NODATA response.
	//
	// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
	NewMsgNODATA(req *dns.Msg) (resp *dns.Msg)

	// NewMsgNXDOMAIN returns a reply to req with the NXDOMAIN code.
	NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg)

	// NewMsgSERVFAIL returns a reply to req with the SERVFAIL code.
	NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg)

	// NewMsgNOTIMPLEMENTED returns a reply to req with the NOTIMPLEMENTED
	// code.
	NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg)
}

// DefaultMessageConstructor is the stateless, standard implementation of
// [MessageConstructor].  Its zero value is ready to use.
type DefaultMessageConstructor struct{}

// type check
var _ MessageConstructor = DefaultMessageConstructor{}

// NewMsgNXDOMAIN implements the [MessageConstructor] interface for
// DefaultMessageConstructor.  The reply carries the name error (NXDOMAIN)
// code.
func (DefaultMessageConstructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
	resp = reply(req, dns.RcodeNameError)

	return resp
}

// NewMsgSERVFAIL implements the [MessageConstructor] interface for
// DefaultMessageConstructor.  The reply carries the server failure (SERVFAIL)
// code.
func (DefaultMessageConstructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
	resp = reply(req, dns.RcodeServerFailure)

	return resp
}

// NewMsgNOTIMPLEMENTED implements the [MessageConstructor] interface for
// DefaultMessageConstructor.  The reply carries the not-implemented code and
// always includes an EDNS0 OPT record.
func (DefaultMessageConstructor) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
	// Most of the Internet and especially the inner core has an MTU of at least
	// 1500 octets.  Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet
	// is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)).
	//
	// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
	const maxUDPPayload = 1452

	resp = reply(req, dns.RcodeNotImplemented)

	// A NOTIMPLEMENTED reply without EDNS is read as "EDNS isn't supported",
	// so add the OPT record explicitly.
	resp.SetEdns0(maxUDPPayload, false)

	return resp
}

// NewMsgNODATA implements the [MessageConstructor] interface for
// DefaultMessageConstructor.  The reply has the NOERROR code, no answers, and
// a single SOA record for the question name in the authority section.  The
// record allows negative caching.  req must contain at least one question.
func (DefaultMessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
	resp = reply(req, dns.RcodeSuccess)

	zone := req.Question[0].Name

	mbox := "hostmaster."
	if !strings.HasPrefix(zone, ".") {
		mbox += zone
	}

	resp.Ns = append(resp.Ns, &dns.SOA{
		// The header is specific to the request.
		Hdr: dns.RR_Header{
			Name:   zone,
			Rrtype: dns.TypeSOA,
			Class:  dns.ClassINET,
			Ttl:    10,
		},
		// Copied from AdGuard DNS.
		Ns:     "fake-for-negative-caching.adguard.com.",
		Mbox:   mbox,
		Serial: 100500,
		// The timers are copied from Verisign's nonexistent .com domain.  Their
		// exact values don't matter here, since they only affect zone transfers
		// between primary and secondary servers.
		Refresh: 1800,
		Retry:   60,
		Expire:  604800,
		Minttl:  86400,
	})

	return resp
}

// reply returns a new reply to req with the given response code and the
// Recursion Available bit set.
func reply(req *dns.Msg, code int) (resp *dns.Msg) {
	resp = new(dns.Msg)
	resp.SetRcode(req, code)
	resp.RecursionAvailable = true

	return resp
}
0707010000002B000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002600000000dnsproxy-0.75.0/internal/dnsproxytest0707010000002C000081A4000000000000000000000001679A649F0000006A000000000000000000000000000000000000003600000000dnsproxy-0.75.0/internal/dnsproxytest/dnsproxytest.go// Package dnsproxytest provides a set of test utilities for the dnsproxy
// module.
package dnsproxytest
0707010000002D000081A4000000000000000000000001679A649F00000AA4000000000000000000000000000000000000003300000000dnsproxy-0.75.0/internal/dnsproxytest/interface.gopackage dnsproxytest

import (
	"github.com/miekg/dns"
)

// FakeUpstream is an [Upstream] for tests whose methods call the matching On*
// function fields.  Calling a method whose field is nil panics.
//
// TODO(e.burkov):  Move this to the golibs some time later.
type FakeUpstream struct {
	OnAddress  func() (addr string)
	OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
	OnClose    func() (err error)
}

// Address implements the [Upstream] interface for *FakeUpstream by calling
// OnAddress.
func (f *FakeUpstream) Address() (addr string) {
	addr = f.OnAddress()

	return addr
}

// Exchange implements the [Upstream] interface for *FakeUpstream by calling
// OnExchange.
func (f *FakeUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	resp, err = f.OnExchange(req)

	return resp, err
}

// Close implements the [Upstream] interface for *FakeUpstream by calling
// OnClose.
func (f *FakeUpstream) Close() (err error) {
	err = f.OnClose()

	return err
}

// TestMessageConstructor is a [dnsmsg.MessageConstructor] for tests whose
// methods call the matching On* function fields.
type TestMessageConstructor struct {
	OnNewMsgNXDOMAIN       func(req *dns.Msg) (resp *dns.Msg)
	OnNewMsgSERVFAIL       func(req *dns.Msg) (resp *dns.Msg)
	OnNewMsgNOTIMPLEMENTED func(req *dns.Msg) (resp *dns.Msg)
	OnNewMsgNODATA         func(req *dns.Msg) (resp *dns.Msg)
}

// NewTestMessageConstructor returns a *TestMessageConstructor whose methods
// all panic until the corresponding On* fields are replaced.
func NewTestMessageConstructor() (c *TestMessageConstructor) {
	c = &TestMessageConstructor{}

	c.OnNewMsgNXDOMAIN = func(_ *dns.Msg) (_ *dns.Msg) {
		panic("unexpected call of TestMessageConstructor.NewMsgNXDOMAIN")
	}
	c.OnNewMsgSERVFAIL = func(_ *dns.Msg) (_ *dns.Msg) {
		panic("unexpected call of TestMessageConstructor.NewMsgSERVFAIL")
	}
	c.OnNewMsgNOTIMPLEMENTED = func(_ *dns.Msg) (_ *dns.Msg) {
		panic("unexpected call of TestMessageConstructor.NewMsgNOTIMPLEMENTED")
	}
	c.OnNewMsgNODATA = func(_ *dns.Msg) (_ *dns.Msg) {
		panic("unexpected call of TestMessageConstructor.NewMsgNODATA")
	}

	return c
}

// NewMsgNXDOMAIN implements the [MessageConstructor] interface for
// *TestMessageConstructor by calling OnNewMsgNXDOMAIN.
func (m *TestMessageConstructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
	resp = m.OnNewMsgNXDOMAIN(req)

	return resp
}

// NewMsgSERVFAIL implements the [MessageConstructor] interface for
// *TestMessageConstructor by calling OnNewMsgSERVFAIL.
func (m *TestMessageConstructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
	resp = m.OnNewMsgSERVFAIL(req)

	return resp
}

// NewMsgNOTIMPLEMENTED implements the [MessageConstructor] interface for
// *TestMessageConstructor by calling OnNewMsgNOTIMPLEMENTED.
func (m *TestMessageConstructor) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
	resp = m.OnNewMsgNOTIMPLEMENTED(req)

	return resp
}

// NewMsgNODATA implements the [MessageConstructor] interface for
// *TestMessageConstructor by calling OnNewMsgNODATA.
func (m *TestMessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
	resp = m.OnNewMsgNODATA(req)

	return resp
}
0707010000002E000081A4000000000000000000000001679A649F00000162000000000000000000000000000000000000003800000000dnsproxy-0.75.0/internal/dnsproxytest/interface_test.gopackage dnsproxytest_test

import (
	"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
	"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
	"github.com/AdguardTeam/dnsproxy/upstream"
)

// type check: FakeUpstream must satisfy upstream.Upstream.
var _ upstream.Upstream = (*dnsproxytest.FakeUpstream)(nil)

// type check: TestMessageConstructor must satisfy dnsmsg.MessageConstructor.
var _ dnsmsg.MessageConstructor = (*dnsproxytest.TestMessageConstructor)(nil)
0707010000002F000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002100000000dnsproxy-0.75.0/internal/handler07070100000030000081A4000000000000000000000001679A649F00000E60000000000000000000000000000000000000003000000000dnsproxy-0.75.0/internal/handler/constructor.gopackage handler

import (
	"net/netip"

	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/miekg/dns"
)

// messageConstructor is a [proxy.MessageConstructor] that can also build
// compressed responses, PTR answers, and A/AAAA responses.
type messageConstructor interface {
	proxy.MessageConstructor

	// NewIPResponse returns an A or AAAA response to req containing ips.  All
	// addresses in ips must belong to the same family.
	NewIPResponse(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg)

	// NewPTRAnswer returns a PTR record for fqdn that points to ptrFQDN.  Both
	// arguments must be fully qualified domain names.
	NewPTRAnswer(fqdn, ptrFQDN string) (ans *dns.PTR)

	// NewCompressedResponse returns a response to req with the given response
	// code and message compression enabled.
	NewCompressedResponse(req *dns.Msg, code int) (resp *dns.Msg)
}

// defaultConstructor implements [messageConstructor].  It embeds a
// [proxy.MessageConstructor] for the basic responses and adds the extra
// response-building methods.
//
// TODO(e.burkov):  This implementation reflects the one from AdGuard Home,
// consider moving it to [golibs].
type defaultConstructor struct {
	proxy.MessageConstructor
}

// type check: the value type, not only the pointer, implements the interface.
var _ messageConstructor = defaultConstructor{}

// NewCompressedResponse implements the [messageConstructor] interface for
// defaultConstructor.  It returns a reply to req with the given code and
// compression turned on.
func (defaultConstructor) NewCompressedResponse(req *dns.Msg, code int) (resp *dns.Msg) {
	msg := reply(req, code)
	msg.Compress = true

	return msg
}

// NewPTRAnswer implements the [messageConstructor] interface for
// [defaultConstructor].  The record is owned by fqdn, and its target is
// ptrFQDN, made fully qualified if it isn't already.
func (defaultConstructor) NewPTRAnswer(fqdn, ptrFQDN string) (ans *dns.PTR) {
	header := hdr(fqdn, dns.TypePTR)
	target := dns.Fqdn(ptrFQDN)

	return &dns.PTR{Hdr: header, Ptr: target}
}

// NewIPResponse implements the [messageConstructor] interface for
// [defaultConstructor].  A questions are answered only if every address in
// ips is IPv4.  AAAA questions are answered with the IPv6 addresses from ips.
// Any other question type gets a successful response with no answers.
func (c defaultConstructor) NewIPResponse(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
	var answers []dns.RR

	switch req.Question[0].Qtype {
	case dns.TypeA:
		answers = genAnswersWithIPv4s(req, ips)
	case dns.TypeAAAA:
		for _, addr := range ips {
			if !addr.Is6() {
				continue
			}

			answers = append(answers, newAnswerAAAA(req, addr))
		}
	}

	resp = c.NewCompressedResponse(req, dns.RcodeSuccess)
	resp.Answer = answers

	return resp
}

// defaultResponseTTL is the TTL, in seconds, put into the resource record
// headers built by [hdr].
const defaultResponseTTL = 10

// hdr returns an IN-class resource record header with the given owner name
// and record type, and with the TTL set to [defaultResponseTTL].
func hdr(name string, rrType uint16) (h dns.RR_Header) {
	h.Name = name
	h.Rrtype = rrType
	h.Class = dns.ClassINET
	h.Ttl = defaultResponseTTL

	return h
}

// reply returns a response to req with the given response code and with the
// Recursion Available flag set.
func reply(req *dns.Msg, code int) (resp *dns.Msg) {
	resp = &dns.Msg{}
	resp.SetRcode(req, code)
	resp.RecursionAvailable = true

	return resp
}

// newAnswerA returns an A record owned by the question name of req and
// pointing to ip.  The caller in this file, genAnswersWithIPv4s, only passes
// IPv4 addresses.
func newAnswerA(req *dns.Msg, ip netip.Addr) (ans *dns.A) {
	owner := req.Question[0].Name

	ans = &dns.A{Hdr: hdr(owner, dns.TypeA)}
	ans.A = ip.AsSlice()

	return ans
}

// newAnswerAAAA returns an AAAA record owned by the question name of req and
// pointing to ip.
func newAnswerAAAA(req *dns.Msg, ip netip.Addr) (ans *dns.AAAA) {
	owner := req.Question[0].Name

	ans = &dns.AAAA{Hdr: hdr(owner, dns.TypeAAAA)}
	ans.AAAA = ip.AsSlice()

	return ans
}

// genAnswersWithIPv4s generates DNS A answers for req from ips.  The result is
// all-or-nothing: if any address in ips isn't an IPv4 address, it returns nil
// and no answers at all.  It also returns nil when ips is empty.
func genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) {
	for _, ip := range ips {
		if !ip.Is4() {
			// Don't mix address families in an A response.  NewIPResponse then
			// replies with an empty answer section.
			return nil
		}

		ans = append(ans, newAnswerA(req, ip))
	}

	return ans
}
07070100000031000081A4000000000000000000000001679A649F00000787000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/handler/default.gopackage handler

import (
	"context"
	"log/slog"

	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/golibs/hostsfile"
)

// DefaultConfig is the configuration for [Default].  It's consumed by
// [NewDefault].
type DefaultConfig struct {
	// MessageConstructor constructs DNS messages.  It must not be nil.  If it
	// doesn't also implement the response-building methods this package needs,
	// NewDefault wraps it to provide them.
	MessageConstructor proxy.MessageConstructor

	// Logger is the logger.  It must not be nil.
	Logger *slog.Logger

	// HostsFiles is the index containing the records of the hosts files.  It's
	// used to answer A, AAAA, and PTR requests before falling back to
	// [proxy.Proxy.Resolve].
	HostsFiles hostsfile.Storage

	// HaltIPv6 halts the processing of AAAA requests and makes the handler
	// reply with NODATA to them.
	HaltIPv6 bool
}

// Default implements the default configurable [proxy.RequestHandler].  Create
// it with [NewDefault].
type Default struct {
	// messages builds the responses returned by the internal handlers.
	messages     messageConstructor
	// hosts holds the hosts-file records used by resolveFromHosts.
	hosts        hostsfile.Storage
	// logger is used for debug logging of the request handling.
	logger       *slog.Logger
	// isIPv6Halted makes haltAAAA answer AAAA requests with NODATA.
	isIPv6Halted bool
}

// NewDefault creates a new [Default] handler.  conf must not be nil, and its
// MessageConstructor and Logger fields must not be nil.  If
// conf.MessageConstructor doesn't implement the additional response-building
// methods, it's wrapped to provide them.  A nil conf.HostsFiles is treated as
// a storage without records, so hosts-based resolution never matches.
func NewDefault(conf *DefaultConfig) (d *Default) {
	mc, ok := conf.MessageConstructor.(messageConstructor)
	if !ok {
		mc = defaultConstructor{
			MessageConstructor: conf.MessageConstructor,
		}
	}

	hosts := conf.HostsFiles
	if hosts == nil {
		// resolveFromHosts calls methods on this field unconditionally, and
		// calling a method on a nil interface panics.
		hosts = emptyStorage{}
	}

	return &Default{
		logger:       conf.Logger,
		isIPv6Halted: conf.HaltIPv6,
		messages:     mc,
		hosts:        hosts,
	}
}

// HandleRequest resolves the DNS request within proxyCtx.  The request is first
// offered to the internal handlers: the AAAA halting and then the hosts-file
// lookup.  [proxy.Proxy.Resolve] is only called when neither of them produces
// a response.
func (h *Default) HandleRequest(p *proxy.Proxy, proxyCtx *proxy.DNSContext) (err error) {
	// TODO(e.burkov):  Use the [*context.Context] instead of
	// [*proxy.DNSContext] when the interface-based handler is implemented.
	ctx := context.TODO()

	req := proxyCtx.Req
	h.logger.DebugContext(ctx, "handling request", "req", &req.Question[0])

	proxyCtx.Res = h.haltAAAA(ctx, req)
	if proxyCtx.Res != nil {
		return nil
	}

	proxyCtx.Res = h.resolveFromHosts(ctx, req)
	if proxyCtx.Res != nil {
		return nil
	}

	return p.Resolve(proxyCtx)
}
07070100000032000081A4000000000000000000000001679A649F0000130F000000000000000000000000000000000000003A00000000dnsproxy-0.75.0/internal/handler/default_internal_test.gopackage handler

import (
	"net"
	"net/netip"
	"os"
	"path"
	"path/filepath"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
	"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMain discards the log output for the whole test binary and exits with
// the status of the test run.
//
// TODO(e.burkov):  Remove when [hostsfile.DefaultStorage] stops using [log].
func TestMain(m *testing.M) {
	testutil.DiscardLogOutput(m)

	code := m.Run()
	os.Exit(code)
}

// TODO(e.burkov):  Add helpers to initialize [proxy.Proxy] to [dnsproxytest]
// and rewrite the tests.

// defaultTimeout is the timeout used for the tests and their contexts.
const defaultTimeout = time.Second

// TestDefault_haltAAAA checks that AAAA requests are answered with NODATA only
// when IPv6 is halted and that A requests are never affected.
func TestDefault_haltAAAA(t *testing.T) {
	t.Parallel()

	reqA := (&dns.Msg{}).SetQuestion("domain.example.", dns.TypeA)
	reqAAAA := (&dns.Msg{}).SetQuestion("domain.example.", dns.TypeAAAA)

	// The NODATA response is only ever produced for the AAAA request, so make
	// it a reply to that request rather than to the A one.
	nodataResp := (&dns.Msg{}).SetReply(reqAAAA)

	messages := dnsproxytest.NewTestMessageConstructor()
	messages.OnNewMsgNODATA = func(_ *dns.Msg) (resp *dns.Msg) {
		return nodataResp
	}

	t.Run("disabled", func(t *testing.T) {
		t.Parallel()

		hdlr := NewDefault(&DefaultConfig{
			Logger:             slogutil.NewDiscardLogger(),
			MessageConstructor: messages,
			HaltIPv6:           false,
		})

		ctx := testutil.ContextWithTimeout(t, defaultTimeout)

		assert.Nil(t, hdlr.haltAAAA(ctx, reqA))
		assert.Nil(t, hdlr.haltAAAA(ctx, reqAAAA))
	})

	t.Run("enabled", func(t *testing.T) {
		t.Parallel()

		hdlr := NewDefault(&DefaultConfig{
			Logger:             slogutil.NewDiscardLogger(),
			MessageConstructor: messages,
			HaltIPv6:           true,
		})

		ctx := testutil.ContextWithTimeout(t, defaultTimeout)

		assert.Nil(t, hdlr.haltAAAA(ctx, reqA))
		assert.Equal(t, nodataResp, hdlr.haltAAAA(ctx, reqAAAA))
	})
}

// TestDefault_resolveFromHosts checks answering A, AAAA, and PTR requests from
// the hosts files, as well as the requests that must not be answered.
func TestDefault_resolveFromHosts(t *testing.T) {
	t.Parallel()

	// TODO(e.burkov):  Use the one from [dnsproxytest].
	messages := dnsmsg.DefaultMessageConstructor{}

	// The same file is read twice, via its relative and absolute paths.  The
	// single-answer checks below therefore also cover repeated records.
	relPath := path.Join("testdata", t.Name(), "hosts")
	absPath, err := filepath.Abs(relPath)
	require.NoError(t, err)

	strg, err := ReadHosts([]string{absPath, relPath})
	require.NoError(t, err)

	hdlr := NewDefault(&DefaultConfig{
		MessageConstructor: messages,
		Logger:             slogutil.NewDiscardLogger(),
		HostsFiles:         strg,
		HaltIPv6:           true,
	})

	const (
		fqdnV4 = "ipv4.domain.example."
		fqdnV6 = "ipv6.domain.example."
	)

	var (
		addrV4 = netip.MustParseAddr("1.2.3.4")
		addrV6 = netip.MustParseAddr("2001:db8::1")

		reversedV4      = errors.Must(netutil.IPToReversedAddr(addrV4.AsSlice()))
		reversedV6      = errors.Must(netutil.IPToReversedAddr(addrV6.AsSlice()))
		unknownReversed = errors.Must(netutil.IPToReversedAddr(net.IP{4, 3, 2, 1}))
	)

	testCases := []struct {
		wantAns dns.RR
		req     *dns.Msg
		name    string
	}{{
		wantAns: &dns.A{
			Hdr: dns.RR_Header{
				Name:   fqdnV4,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    defaultResponseTTL,
			},
			A: addrV4.AsSlice(),
		},
		req:  (&dns.Msg{}).SetQuestion(fqdnV4, dns.TypeA),
		name: "success_a",
	}, {
		wantAns: &dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   fqdnV6,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    defaultResponseTTL,
			},
			AAAA: addrV6.AsSlice(),
		},
		req:  (&dns.Msg{}).SetQuestion(fqdnV6, dns.TypeAAAA),
		name: "success_aaaa",
	}, {
		wantAns: &dns.PTR{
			Hdr: dns.RR_Header{
				Name:   reversedV4,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
				Ttl:    defaultResponseTTL,
			},
			Ptr: fqdnV4,
		},
		req:  (&dns.Msg{}).SetQuestion(reversedV4, dns.TypePTR),
		name: "success_ptr_v4",
	}, {
		wantAns: &dns.PTR{
			Hdr: dns.RR_Header{
				Name:   reversedV6,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
				Ttl:    defaultResponseTTL,
			},
			Ptr: fqdnV6,
		},
		req:  (&dns.Msg{}).SetQuestion(reversedV6, dns.TypePTR),
		name: "success_ptr_v6",
	}, {
		wantAns: nil,
		req:     (&dns.Msg{}).SetQuestion("unknown.example", dns.TypeA),
		name:    "not_found_a",
	}, {
		wantAns: nil,
		req:     (&dns.Msg{}).SetQuestion("unknown.example", dns.TypeAAAA),
		name:    "not_found_aaaa",
	}, {
		wantAns: nil,
		req:     (&dns.Msg{}).SetQuestion(unknownReversed, dns.TypePTR),
		name:    "not_found_ptr",
	}, {
		wantAns: nil,
		req:     (&dns.Msg{}).SetQuestion("bad.ptr", dns.TypePTR),
		name:    "bad_ptr",
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ctx := testutil.ContextWithTimeout(t, defaultTimeout)
			resp := hdlr.resolveFromHosts(ctx, tc.req)
			if tc.wantAns == nil {
				assert.Nil(t, resp)

				return
			}

			require.NotNil(t, resp)
			require.Len(t, resp.Answer, 1)
			assert.Equal(t, tc.wantAns, resp.Answer[0])
		})
	}
}
07070100000033000081A4000000000000000000000001679A649F0000006F000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/handler/handler.go// Package handler provides some customizable DNS request handling logic used in
// the proxy.
package handler
07070100000034000081A4000000000000000000000001679A649F00000D50000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/internal/handler/hosts.gopackage handler

import (
	"context"
	"fmt"
	"net/netip"
	"os"
	"slices"
	"strings"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/hostsfile"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

// emptyStorage is a [hostsfile.Storage] without any records: both lookups
// always return nil.
//
// TODO(e.burkov):  Move to [hostsfile].
type emptyStorage [0]hostsfile.Record

// type check
var _ hostsfile.Storage = emptyStorage{}

// ByAddr implements the [hostsfile.Storage] interface for [emptyStorage].  It
// always returns nil.
func (emptyStorage) ByAddr(netip.Addr) (names []string) {
	return names
}

// ByName implements the [hostsfile.Storage] interface for [emptyStorage].  It
// always returns nil.
func (emptyStorage) ByName(string) (addrs []netip.Addr) {
	return addrs
}

// ReadHosts parses the hosts files at paths into a single storage.  A failure
// on one file doesn't stop the remaining ones from being read; all such errors
// are joined into err.  strg is always usable, even if err is not nil.  If no
// records were read at all, strg is an empty storage.
func ReadHosts(paths []string) (strg hostsfile.Storage, err error) {
	// The constructor can only fail because of the readers, and none are
	// passed here, so the error is ignored.
	index, _ := hostsfile.NewDefaultStorage()

	var errs []error
	for _, p := range paths {
		// Don't wrap the error since it's informative enough as is.
		if readErr := readHostsFile(index, p); readErr != nil {
			errs = append(errs, readErr)
		}
	}

	err = errors.Join(errs...)

	// TODO(e.burkov):  Add method for length.
	hasRecords := false
	index.RangeAddrs(func(_ string, _ []netip.Addr) (cont bool) {
		hasRecords = true

		// A single record is enough, stop iterating.
		return false
	})

	if !hasRecords {
		return emptyStorage{}, err
	}

	return index, err
}

// readHostsFile opens the hosts file at path and parses its records into strg.
// An error from closing the file is combined with the returned one.
func readHostsFile(strg *hostsfile.DefaultStorage, path string) (err error) {
	// #nosec G304 -- Trust the file path from the configuration file.
	file, err := os.Open(path)
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}
	defer func() { err = errors.WithDeferred(err, file.Close()) }()

	if err = hostsfile.Parse(strg, file, nil); err != nil {
		return fmt.Errorf("parsing hosts file %q: %w", path, err)
	}

	return nil
}

// resolveFromHosts tries to answer req using the hosts-file records.  A and
// AAAA questions are answered with the addresses of the matching family.  PTR
// questions are answered with the host names of the reversed address.  It
// returns nil if the question type isn't supported, a PTR name can't be
// parsed, or nothing matches.
func (h *Default) resolveFromHosts(ctx context.Context, req *dns.Msg) (resp *dns.Msg) {
	question := req.Question[0]
	host := strings.TrimSuffix(question.Name, ".")

	var (
		ips   []netip.Addr
		names []string
	)

	switch question.Qtype {
	case dns.TypeA, dns.TypeAAAA:
		// Clone the result, because it's filtered in place below.
		ips = slices.Clone(h.hosts.ByName(host))

		isOtherFamily := netip.Addr.Is6
		if question.Qtype == dns.TypeAAAA {
			isOtherFamily = netip.Addr.Is4
		}

		ips = slices.DeleteFunc(ips, isOtherFamily)
	case dns.TypePTR:
		ip, err := netutil.IPFromReversedAddr(host)
		if err != nil {
			h.logger.DebugContext(ctx, "failed parsing ptr", slogutil.KeyError, err)

			return nil
		}

		names = h.hosts.ByAddr(ip)
	default:
		return nil
	}

	if len(ips) > 0 {
		return h.messages.NewIPResponse(req, ips)
	}

	if len(names) == 0 {
		h.logger.DebugContext(ctx, "no hosts records found", "name", host, "qtype", question.Qtype)

		return nil
	}

	resp = h.messages.NewCompressedResponse(req, dns.RcodeSuccess)
	for _, n := range names {
		resp.Answer = append(resp.Answer, h.messages.NewPTRAnswer(question.Name, dns.Fqdn(n)))
	}

	return resp
}
07070100000035000081A4000000000000000000000001679A649F000001E2000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/internal/handler/ipv6halt.gopackage handler

import (
	"context"

	"github.com/miekg/dns"
)

// haltAAAA returns a NODATA response for req if IPv6 processing is halted and
// req is an AAAA request.  Otherwise, it returns nil.  req must not be nil.
func (h *Default) haltAAAA(ctx context.Context, req *dns.Msg) (resp *dns.Msg) {
	if !h.isIPv6Halted {
		return nil
	}

	q := &req.Question[0]
	if q.Qtype != dns.TypeAAAA {
		return nil
	}

	h.logger.DebugContext(ctx, "ipv6 is disabled; replying with empty response", "req", q.Name)

	return h.messages.NewMsgNODATA(req)
}
07070100000036000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/internal/handler/testdata07070100000037000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000004700000000dnsproxy-0.75.0/internal/handler/testdata/TestDefault_resolveFromHosts07070100000038000081A4000000000000000000000001679A649F0000004A000000000000000000000000000000000000004D00000000dnsproxy-0.75.0/internal/handler/testdata/TestDefault_resolveFromHosts/hosts1.2.3.4     ipv4.domain.example
2001:db8::1 ipv6.domain.example
# comment
07070100000039000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002100000000dnsproxy-0.75.0/internal/netutil0707010000003A000081A4000000000000000000000001679A649F000002AC000000000000000000000000000000000000003100000000dnsproxy-0.75.0/internal/netutil/listenconfig.gopackage netutil

import (
	"log/slog"
	"net"
)

// ListenConfig returns the default [net.ListenConfig] used by the plain-DNS
// servers in this module.  l must not be nil.
//
// TODO(a.garipov): Add tests.
//
// TODO(a.garipov): Add an option to not set SO_REUSEPORT on Unix to prevent
// issues with OpenWrt.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/5872.
//
// TODO(a.garipov): DRY with AdGuard DNS when we can.
func ListenConfig(l *slog.Logger) (lc *net.ListenConfig) {
	return &net.ListenConfig{
		Control: listenControl{logger: l}.defaultListenControl,
	}
}

// listenControl is a wrapper struct with logger.
type listenControl struct {
	logger *slog.Logger
}
0707010000003B000081A4000000000000000000000001679A649F0000048A000000000000000000000000000000000000003600000000dnsproxy-0.75.0/internal/netutil/listenconfig_unix.go//go:build unix

package netutil

import (
	"fmt"
	"syscall"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"golang.org/x/sys/unix"
)

// defaultListenControl is used as a [net.ListenConfig.Control] function.  It
// sets SO_REUSEADDR and SO_REUSEPORT on the socket.  A lack of SO_REUSEPORT
// support (ENOPROTOOPT) only produces a warning; any other failure is returned
// together with the error of the Control call itself.
func (lc listenControl) defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
	var sockErr error
	setOpts := func(fd uintptr) {
		sock := int(fd)

		if e := unix.SetsockoptInt(sock, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); e != nil {
			sockErr = fmt.Errorf("setting SO_REUSEADDR: %w", e)

			return
		}

		e := unix.SetsockoptInt(sock, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
		switch {
		case e == nil:
			// Both options are set.
		case errors.Is(e, unix.ENOPROTOOPT):
			// Some Linux OSs do not seem to support SO_REUSEPORT, including
			// some varieties of OpenWrt.  Issue a warning.
			lc.logger.Warn("SO_REUSEPORT not supported", slogutil.KeyError, e)
		default:
			sockErr = fmt.Errorf("setting SO_REUSEPORT: %w", e)
		}
	}

	err = c.Control(setOpts)

	return errors.WithDeferred(sockErr, err)
}
0707010000003C000081A4000000000000000000000001679A649F000000F4000000000000000000000000000000000000003900000000dnsproxy-0.75.0/internal/netutil/listenconfig_windows.go//go:build windows

package netutil

import "syscall"

// defaultListenControl is used as a [net.ListenConfig.Control] function on
// Windows.  It's a no-op that leaves the socket options untouched and always
// returns nil, because Windows doesn't support SO_REUSEPORT.
func (listenControl) defaultListenControl(_, _ string, _ syscall.RawConn) (err error) {
	return nil
}
0707010000003D000081A4000000000000000000000001679A649F00000302000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/netutil/netutil.go// Package netutil contains network-related utilities common among dnsproxy
// packages.
//
// TODO(a.garipov): Move improved versions of these into netutil in module
// golibs.
package netutil

import (
	"net/netip"
	"strings"
)

// ParseSubnet parses s as a CIDR prefix if it contains a slash.  Otherwise, s
// is parsed as a single IP address, which is turned into a prefix covering
// only that address (/32 for IPv4, /128 for IPv6).  On error, the zero prefix
// is returned.
//
// TODO(e.burkov):  Replace usages with [netutil.Prefix].
func ParseSubnet(s string) (p netip.Prefix, err error) {
	if !strings.Contains(s, "/") {
		ip, parseErr := netip.ParseAddr(s)
		if parseErr != nil {
			return netip.Prefix{}, parseErr
		}

		return netip.PrefixFrom(ip, ip.BitLen()), nil
	}

	p, err = netip.ParsePrefix(s)
	if err != nil {
		return netip.Prefix{}, err
	}

	return p, nil
}
0707010000003E000081A4000000000000000000000001679A649F00000128000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/internal/netutil/paths.gopackage netutil

// DefaultHostsPaths returns the paths to the system hosts files of the
// current platform.
//
// TODO(s.chzhen):  Since [fs.FS] is no longer needed, update the
// [hostsfile.DefaultHostsPaths] from golibs.
func DefaultHostsPaths() (paths []string, err error) {
	paths, err = defaultHostsPaths()

	return paths, err
}
0707010000003F000081A4000000000000000000000001679A649F000001C4000000000000000000000000000000000000002F00000000dnsproxy-0.75.0/internal/netutil/paths_unix.go//go:build unix

package netutil

import "github.com/AdguardTeam/golibs/hostsfile"

// defaultHostsPaths returns the default absolute paths to the hosts files on
// UNIX.  [hostsfile.DefaultHostsPaths] returns them relative to the root of the
// file system, so a leading slash is prepended to each of them.
func defaultHostsPaths() (paths []string, err error) {
	relPaths, err := hostsfile.DefaultHostsPaths()
	if err != nil {
		// Currently, this never happens, but propagate the error instead of
		// crashing the whole program should that ever change.
		return nil, err
	}

	paths = make([]string, 0, len(relPaths))
	for _, p := range relPaths {
		paths = append(paths, "/"+p)
	}

	return paths, nil
}
07070100000040000081A4000000000000000000000001679A649F000001B1000000000000000000000000000000000000003200000000dnsproxy-0.75.0/internal/netutil/paths_windows.go//go:build windows

package netutil

import (
	"fmt"
	"path"

	"golang.org/x/sys/windows"
)

// defaultHostsPaths returns the default path to the hosts file on Windows.
// That is the drivers/etc/hosts file inside the system directory.
func defaultHostsPaths() (paths []string, err error) {
	sysDir, err := windows.GetSystemDirectory()
	if err != nil {
		return []string{}, fmt.Errorf("getting system directory: %w", err)
	}

	paths = []string{path.Join(sysDir, "drivers", "etc", "hosts")}

	return paths, nil
}
07070100000041000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/internal/netutil/testdata07070100000042000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000003400000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts07070100000043000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000003D00000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts/bad_file07070100000044000081A4000000000000000000000001679A649F000000B6000000000000000000000000000000000000004300000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts/bad_file/hosts# comment about the following empty line

# comment about the above empty line

1.2.3.256 a.b # invalid address
1.2.3.4 a.123 # invalid top-level domain
1.2.3.4 .a.b  # empty domain
07070100000045000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000003E00000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts/good_file07070100000046000081A4000000000000000000000001679A649F0000033A000000000000000000000000000000000000004400000000dnsproxy-0.75.0/internal/netutil/testdata/TestHosts/good_file/hosts# IPv4

# 1st host.
0.0.0.1 Host.One

# 2nd host.
0.0.0.2 Host.Two

# 1st host full duplicate.
0.0.0.1 host.one

# 2nd host duplicate with new name.
0.0.0.2 host.two Host.New

# 1st host with foreign name.
0.0.0.1 host.new

# 2nd host new name.
0.0.0.2 Again.Host.Two

# Mapped

# 1st host.
::ffff:0.0.0.1 Host.One

# 2nd host.
::ffff:0.0.0.2 Host.Two

# 1st host full duplicate.
::ffff:0.0.0.1 host.one

# 2nd host duplicate with new name.
::ffff:0.0.0.2 host.two Host.New

# 1st host with foreign name.
::ffff:0.0.0.1 host.new

# 2nd host new name.
::ffff:0.0.0.2 Again.Host.Two

# IPv6

# 1st host.
::1 Host.One

# 2nd host.
::2 Host.Two

# 1st host full duplicate.
::1 host.one

# 2nd host duplicate with new name.
::2 host.two Host.New

# 1st host with foreign name.
::1 host.new

# 2nd host new name.
::2 Again.Host.Two
07070100000047000081A4000000000000000000000001679A649F0000042A000000000000000000000000000000000000002800000000dnsproxy-0.75.0/internal/netutil/udp.gopackage netutil

import (
	"net"
	"net/netip"
)

// UDPGetOOBSize returns the size of the buffer needed to receive the OOB
// (control-message) data used by [UDPRead].
func UDPGetOOBSize() (oobSize int) {
	oobSize = udpGetOOBSize()

	return oobSize
}

// UDPSetOptions enables the socket options on c that are needed to receive
// the OOB data read by [UDPRead].
func UDPSetOptions(c *net.UDPConn) (err error) {
	err = udpSetOptions(c)

	return err
}

// UDPRead reads a message from conn into buf, together with up to udpOOBSize
// bytes of control-message data.  It returns the number of bytes copied into
// buf, the local address the message was sent to (the zero [netip.Addr] if it
// isn't known), and the source address of the message.
//
// TODO(s.chzhen):  Consider using netip.Addr.
func UDPRead(
	conn *net.UDPConn,
	buf []byte,
	udpOOBSize int,
) (n int, localIP netip.Addr, remoteAddr *net.UDPAddr, err error) {
	n, localIP, remoteAddr, err = udpRead(conn, buf, udpOOBSize)

	return n, localIP, remoteAddr, err
}

// UDPWrite sends data to remoteAddr over conn.  Where the platform supports
// it, localIP is used as the source address of the packet.
//
// TODO(s.chzhen):  Consider using netip.Addr.
func UDPWrite(
	data []byte,
	conn *net.UDPConn,
	remoteAddr *net.UDPAddr,
	localIP netip.Addr,
) (n int, err error) {
	n, err = udpWrite(data, conn, remoteAddr, localIP)

	return n, err
}
07070100000048000081A4000000000000000000000001679A649F00000856000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/internal/netutil/udp_unix.go//go:build unix

package netutil

import (
	"fmt"
	"net"
	"net/netip"

	"github.com/AdguardTeam/golibs/netutil"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

// ipv4Flags and ipv6Flags are the control-message flags that make an IPv4 or
// IPv6 UDP connection deliver, as OOB data, the destination address of each
// packet (FlagDst) and the interface it arrived on (FlagInterface).
const (
	ipv4Flags = ipv4.FlagDst | ipv4.FlagInterface
	ipv6Flags = ipv6.FlagDst | ipv6.FlagInterface
)

// udpGetOOBSize returns the size of an OOB buffer large enough to hold either
// an IPv4 control message with [ipv4Flags] or an IPv6 one with [ipv6Flags].
func udpGetOOBSize() (oobSize int) {
	return max(len(ipv4.NewControlMessage(ipv4Flags)), len(ipv6.NewControlMessage(ipv6Flags)))
}

// udpSetOptions enables receiving the control messages described by
// [ipv6Flags] and [ipv4Flags] on c.  It fails only if both calls fail.  A
// single failure is tolerated, presumably because the socket may serve just
// one address family.  The returned error wraps both underlying errors, so
// callers can inspect them with errors.Is and errors.As.
func udpSetOptions(c *net.UDPConn) (err error) {
	err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6Flags, true)
	err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4Flags, true)
	if err6 != nil && err4 != nil {
		return fmt.Errorf("failed to call SetControlMessage: ipv4: %w; ipv6: %w", err4, err6)
	}

	return nil
}

// udpGetDstFromOOB extracts the destination address of a received packet from
// its OOB data.  The IPv6 control message is tried first, then the IPv4 one.
// If neither contains a destination, it returns the zero [netip.Addr] and a
// nil error.
func udpGetDstFromOOB(oob []byte) (dst netip.Addr, err error) {
	var v6 ipv6.ControlMessage
	if v6.Parse(oob) == nil && v6.Dst != nil {
		// Linux maps IPv4 addresses to IPv6 ones by default, so an IPv4
		// destination can arrive in an IPv6 control message.
		return netutil.IPToAddrNoMapped(v6.Dst)
	}

	var v4 ipv4.ControlMessage
	if v4.Parse(oob) == nil && v4.Dst != nil {
		return netutil.IPToAddr(v4.Dst, netutil.AddrFamilyIPv4)
	}

	return netip.Addr{}, nil
}

// udpRead reads a packet from c into buf, together with at most udpOOBSize
// bytes of OOB data, and extracts the packet's destination address from the
// latter.  On error, n is -1 and the other results are zero values.
func udpRead(
	c *net.UDPConn,
	buf []byte,
	udpOOBSize int,
) (n int, localIP netip.Addr, remoteAddr *net.UDPAddr, err error) {
	oob := make([]byte, udpOOBSize)

	n, oobLen, _, from, err := c.ReadMsgUDP(buf, oob)
	if err != nil {
		return -1, netip.Addr{}, nil, err
	}

	dst, err := udpGetDstFromOOB(oob[:oobLen])
	if err != nil {
		return -1, netip.Addr{}, nil, err
	}

	return n, dst, from, nil
}

// udpWrite sends data to remoteAddr over conn, attaching the OOB data built by
// udpMakeOOBWithSrc for localIP.  It returns the number of payload bytes
// written.
func udpWrite(
	data []byte,
	conn *net.UDPConn,
	remoteAddr *net.UDPAddr,
	localIP netip.Addr,
) (n int, err error) {
	oob := udpMakeOOBWithSrc(localIP)
	n, _, err = conn.WriteMsgUDP(data, oob, remoteAddr)

	return n, err
}
07070100000049000081A4000000000000000000000001679A649F00000229000000000000000000000000000000000000003000000000dnsproxy-0.75.0/internal/netutil/udp_windows.go//go:build windows

package netutil

import (
	"net"
	"net/netip"
)

// udpGetOOBSize returns 0, since OOB data isn't used on Windows.
func udpGetOOBSize() int {
	const noOOB = 0

	return noOOB
}

// udpSetOptions is a no-op on Windows, since OOB data isn't used there.
func udpSetOptions(_ *net.UDPConn) error {
	return nil
}

// udpRead reads a packet from c into buf.  OOB data isn't used on Windows, so
// the OOB size is ignored and the local address is always the zero
// [netip.Addr].
func udpRead(c *net.UDPConn, buf []byte, _ int) (int, netip.Addr, *net.UDPAddr, error) {
	// ReadFromUDP returns the typed source address directly, so there's no
	// need for an unchecked type assertion on a [net.Addr].
	n, udpAddr, err := c.ReadFromUDP(buf)

	return n, netip.Addr{}, udpAddr, err
}

// udpWrite sends data to remoteAddr over conn.  The source address can't be
// set via OOB data on Windows, so the local address is ignored.
func udpWrite(data []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, _ netip.Addr) (int, error) {
	return conn.WriteToUDP(data, remoteAddr)
}
0707010000004A000081A4000000000000000000000001679A649F0000026F000000000000000000000000000000000000003200000000dnsproxy-0.75.0/internal/netutil/udpoob_darwin.go//go:build darwin

package netutil

import (
	"net/netip"

	"golang.org/x/net/ipv6"
)

// udpMakeOOBWithSrc returns the OOB data that sets the source address of an
// outgoing packet to ip.  On darwin, IPv4 source addresses are deliberately
// not set, so an empty, non-nil slice is returned for them.
func udpMakeOOBWithSrc(ip netip.Addr) (b []byte) {
	if !ip.Is4() {
		cm := &ipv6.ControlMessage{Src: ip.AsSlice()}

		return cm.Marshal()
	}

	// Do not set the IPv4 source address via OOB, because it can cause the
	// address to become unspecified on darwin.
	//
	// See https://github.com/AdguardTeam/AdGuardHome/issues/2807.
	//
	// TODO(e.burkov): Develop a workaround to make it write OOB only when
	// listening on an unspecified address.
	return []byte{}
}
0707010000004B000081A4000000000000000000000001679A649F00000186000000000000000000000000000000000000003200000000dnsproxy-0.75.0/internal/netutil/udpoob_others.go//go:build !darwin

package netutil

import (
	"net/netip"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

// udpMakeOOBWithSrc makes the OOB data with the specified source IP.
func udpMakeOOBWithSrc(ip netip.Addr) (b []byte) {
	if ip.Is4() {
		return (&ipv4.ControlMessage{
			Src: ip.AsSlice(),
		}).Marshal()
	}

	return (&ipv6.ControlMessage{
		Src: ip.AsSlice(),
	}).Marshal()
}
0707010000004C000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/internal/tools0707010000004D000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000002100000000dnsproxy-0.75.0/internal/version0707010000004E000081A4000000000000000000000001679A649F00000351000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/internal/version/version.go// Package version contains dnsproxy version information.
package version

// Versions

// These are set by the linker.  Unfortunately, we cannot set constants during
// linking, and Go doesn't have a concept of immutable variables, so to be
// thorough we have to only export them through getters.
var (
	branch     string
	committime string
	revision   string
	version    string
)

// Branch returns the Git branch set by the linker, or an empty string if it
// wasn't set.
func Branch() (b string) {
	b = branch

	return b
}

// CommitTime returns the commit time set by the linker as a string, or an
// empty string if it wasn't set.
func CommitTime() (t string) {
	t = committime

	return t
}

// Revision returns the Git revision set by the linker, or an empty string if
// it wasn't set.
func Revision() (r string) {
	r = revision

	return r
}

// Version returns the build version set by the linker as a string, or an
// empty string if it wasn't set.
func Version() (v string) {
	v = version

	return v
}
0707010000004F000081A4000000000000000000000001679A649F00000066000000000000000000000000000000000000001800000000dnsproxy-0.75.0/main.gopackage main

import (
	"github.com/AdguardTeam/dnsproxy/internal/cmd"
)

// main is the entry point of the dnsproxy binary.  It delegates all the work
// to [cmd.Main].
func main() {
	cmd.Main()
}
07070100000050000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001600000000dnsproxy-0.75.0/proxy07070100000051000081A4000000000000000000000001679A649F00000A46000000000000000000000000000000000000002700000000dnsproxy-0.75.0/proxy/beforerequest.gopackage proxy

import (
	"fmt"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
)

// BeforeRequestError is an error that signals that the request should be
// responded with the given response message.
type BeforeRequestError struct {
	// Err is the error that caused the response.  It must not be nil.
	Err error

	// Response is the response message to be sent to the client.  It must be a
	// valid response message.
	Response *dns.Msg
}

// type check
var _ error = (*BeforeRequestError)(nil)

// Error implements the [error] interface for *BeforeRequestError.  It doesn't
// panic if e.Response is nil, because the error is logged before its response
// is used.
func (e *BeforeRequestError) Error() (msg string) {
	if e.Response == nil {
		return fmt.Sprintf("%s; no response", e.Err)
	}

	return fmt.Sprintf("%s; respond with %s", e.Err, dns.RcodeToString[e.Response.Rcode])
}

// type check
var _ errors.Wrapper = (*BeforeRequestError)(nil)

// Unwrap implements the [errors.Wrapper] interface for *BeforeRequestError.
func (e *BeforeRequestError) Unwrap() (unwrapped error) {
	return e.Err
}

// BeforeRequestHandler is an object that can handle the request before it's
// processed by [Proxy].
type BeforeRequestHandler interface {
	// HandleBefore is called before processing of each DNS request starts.
	// The passed [DNSContext] has its Req, Addr, and IsLocalClient fields set
	// accordingly.
	//
	// If err is nil, the request is processed further.  If the returned err is
	// a [BeforeRequestError], its response message is sent to the client and
	// the request isn't processed further.  Any other error stops the
	// processing without sending a response.  [Proxy] assumes that a handler
	// doesn't set the [DNSContext.Res] field itself.
	HandleBefore(p *Proxy, dctx *DNSContext) (err error)
}

// noopRequestHandler is a [BeforeRequestHandler] that lets every request
// through.
type noopRequestHandler struct{}

// type check
var _ BeforeRequestHandler = noopRequestHandler{}

// HandleBefore implements the [BeforeRequestHandler] interface for
// noopRequestHandler.  It always returns nil, so the request is never stopped.
func (noopRequestHandler) HandleBefore(*Proxy, *DNSContext) (err error) {
	return nil
}

// handleBefore runs the configured [BeforeRequestHandler] for d.  It returns
// true if the handler returned no error and the request should be processed
// further.  If the handler returned a [BeforeRequestError], its response is
// sent to the client.  Any other error makes the request be ignored.  In both
// error cases, false is returned.
func (p *Proxy) handleBefore(d *DNSContext) (cont bool) {
	err := p.beforeRequestHandler.HandleBefore(p, d)
	if err == nil {
		return true
	}

	p.logger.Debug("handling before request", slogutil.KeyError, err)

	var befReqErr *BeforeRequestError
	if !errors.As(err, &befReqErr) {
		return false
	}

	d.Res = befReqErr.Response
	p.logDNSMessage(d.Res)
	p.respond(d)

	return false
}
07070100000052000081A4000000000000000000000001679A649F00000D14000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/proxy/beforerequest_test.gopackage proxy

import (
	"context"
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testBeforeRequestHandler is a [BeforeRequestHandler] for tests that
// delegates every call to its onHandleBefore function.
type testBeforeRequestHandler struct {
	onHandleBefore func(p *Proxy, dctx *DNSContext) (err error)
}

// HandleBefore implements the [BeforeRequestHandler] interface for
// *testBeforeRequestHandler.
func (h *testBeforeRequestHandler) HandleBefore(p *Proxy, dctx *DNSContext) (err error) {
	err = h.onHandleBefore(p, dctx)

	return err
}

// type check
var _ BeforeRequestHandler = (*testBeforeRequestHandler)(nil)

// TestProxy_HandleDNSRequest_beforeRequestHandler checks that the proxy honors
// the result of its [BeforeRequestHandler]: a nil error lets the request
// through to the upstream, a plain error drops it without a response, and a
// [BeforeRequestError] makes the proxy reply with the error's Response.
func TestProxy_HandleDNSRequest_beforeRequestHandler(t *testing.T) {
	t.Parallel()

	// The request IDs select the behavior of the handler below.
	const (
		allowedID = iota
		droppedID
		errorID
	)

	allowedRequest := (&dns.Msg{}).SetQuestion("allowed.", dns.TypeA)
	allowedRequest.Id = allowedID
	allowedResponse := (&dns.Msg{}).SetReply(allowedRequest)

	droppedRequest := (&dns.Msg{}).SetQuestion("dropped.", dns.TypeA)
	droppedRequest.Id = droppedID

	errorRequest := (&dns.Msg{}).SetQuestion("error.", dns.TypeA)
	errorRequest.Id = errorID
	errorResponse := (&dns.Msg{}).SetReply(errorRequest)

	p := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{&fakeUpstream{
				onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) {
					return allowedResponse.Copy(), nil
				},
				onAddress: func() (addr string) { return "general" },
				onClose:   func() (err error) { return nil },
			}},
		},
		TrustedProxies: defaultTrustedProxies,
		PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
		BeforeRequestHandler: &testBeforeRequestHandler{
			onHandleBefore: func(p *Proxy, dctx *DNSContext) (err error) {
				switch dctx.Req.Id {
				case allowedID:
					return nil
				case droppedID:
					return errors.Error("just drop")
				case errorID:
					return &BeforeRequestError{
						Err:      errors.Error("just error"),
						Response: errorResponse,
					}
				default:
					panic(fmt.Sprintf("unexpected request id: %d", dctx.Req.Id))
				}
			},
		},
	})
	ctx := context.Background()
	require.NoError(t, p.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

	// The client timeout is short because the "dropped" subtest waits for a
	// response that never comes.
	client := &dns.Client{
		Net:     string(ProtoTCP),
		Timeout: 200 * time.Millisecond,
	}
	addr := p.Addr(ProtoTCP).String()

	t.Run("allowed", func(t *testing.T) {
		t.Parallel()

		resp, _, err := client.Exchange(allowedRequest, addr)
		require.NoError(t, err)
		assert.Equal(t, allowedResponse, resp)
	})

	t.Run("dropped", func(t *testing.T) {
		t.Parallel()

		resp, _, err := client.Exchange(droppedRequest, addr)

		// A dropped request gets no reply, so the client times out.
		wantErr := &net.OpError{}
		require.ErrorAs(t, err, &wantErr)
		assert.True(t, wantErr.Timeout())

		assert.Nil(t, resp)
	})

	t.Run("error", func(t *testing.T) {
		t.Parallel()

		resp, _, err := client.Exchange(errorRequest, addr)
		require.NoError(t, err)
		assert.Equal(t, errorResponse, resp)
	})
}
07070100000053000081A4000000000000000000000001679A649F000002AF000000000000000000000000000000000000002700000000dnsproxy-0.75.0/proxy/bogusnxdomain.gopackage proxy

import (
	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

// isBogusNXDomain reports whether m is a response to an A or AAAA question
// whose Answer section has at least one IP address within the BogusNXDomain
// subnets of p.  It returns false for a nil m, a question-less m, or when no
// bogus subnets are configured.
func (p *Proxy) isBogusNXDomain(m *dns.Msg) (ok bool) {
	if m == nil || len(m.Question) == 0 || len(p.BogusNXDomain) == 0 {
		return false
	}

	switch m.Question[0].Qtype {
	case dns.TypeA, dns.TypeAAAA:
		// Go on.
	default:
		return false
	}

	bogus := netutil.SliceSubnetSet(p.BogusNXDomain)
	for i := range m.Answer {
		if bogus.Contains(proxyutil.IPFromRR(m.Answer[i])) {
			return true
		}
	}

	return false
}
07070100000054000081A4000000000000000000000001679A649F00000B72000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/proxy/bogusnxdomain_test.gopackage proxy

import (
	"context"
	"net"
	"net/netip"
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestProxy_IsBogusNXDomain checks that responses with IP addresses from the
// configured BogusNXDomain subnets are turned into NXDOMAIN by the proxy, while
// other responses keep the NOERROR code.
func TestProxy_IsBogusNXDomain(t *testing.T) {
	prx := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
		BogusNXDomain: []netip.Prefix{
			netip.MustParsePrefix("4.3.2.1/24"),
			netip.MustParsePrefix("1.2.3.4/8"),
			netip.MustParsePrefix("10.11.12.13/32"),
			netip.MustParsePrefix("102:304:506:708:90a:b0c:d0e:f10/120"),
		},
	})

	testCases := []struct {
		name      string
		ans       []dns.RR
		wantRcode int
	}{{
		name: "bogus_subnet",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("4.3.2.1"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "bogus_big_subnet",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("1.254.254.254"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "bogus_single_ip",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("10.11.12.13"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "bogus_6",
		ans: []dns.RR{&dns.AAAA{
			Hdr:  dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10},
			AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 99},
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "non-bogus",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("10.11.12.14"),
		}},
		wantRcode: dns.RcodeSuccess,
	}, {
		name: "non-bogus_6",
		ans: []dns.RR{&dns.AAAA{
			Hdr:  dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10},
			AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 15},
		}},
		wantRcode: dns.RcodeSuccess,
	}}

	// Replace the configured upstream with one whose answer is controlled by
	// each test case.
	u := testUpstream{}
	prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u}

	ctx := context.Background()
	err := prx.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

	// NOTE(review): d, u, and the proxy's cache are shared by all subtests, so
	// the subtests aren't independent and must not be run in parallel.
	d := &DNSContext{
		Req: newHostTestMessage("host"),
	}

	for _, tc := range testCases {
		u.ans = tc.ans

		t.Run(tc.name, func(t *testing.T) {
			err = prx.Resolve(d)
			require.NoError(t, err)
			require.NotNil(t, d.Res)

			assert.Equal(t, tc.wantRcode, d.Res.Rcode)
		})
	}
}
07070100000055000081A4000000000000000000000001679A649F00003F27000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/cache.gopackage proxy

import (
	"bytes"
	"encoding/binary"
	"log/slog"
	"math"
	"net"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	glcache "github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/mathutil"
	"github.com/miekg/dns"
)

// defaultCacheSize is the size of cache in bytes by default.  createCache uses
// it when the configured size isn't positive.
const defaultCacheSize = 64 * 1024

// cache is used to cache requests and used upstreams.  It keeps two separate
// stores: one for responses looked up by the question only and one for
// responses looked up by the question and the ECS subnet.
type cache struct {
	// itemsLock protects items.
	itemsLock *sync.RWMutex

	// itemsWithSubnetLock protects itemsWithSubnet.
	itemsWithSubnetLock *sync.RWMutex

	// items is the requests cache keyed by the question, see msgToKey.
	items glcache.Cache

	// itemsWithSubnet is the requests cache keyed by the question and the ECS
	// subnet, see msgToKeyWithSubnet.  It's nil if ECS is disabled.
	itemsWithSubnet glcache.Cache

	// optimistic defines if the cache should return expired items and resolve
	// those again.
	optimistic bool
}

// cacheItem is a single cache entry.  It's a helper type to aggregate the
// item-specific logic.
type cacheItem struct {
	// m contains the cached response.
	m *dns.Msg

	// u contains an address of the upstream which resolved m.  It's empty if
	// no upstream was given when the item was created.
	u string

	// ttl is the time-to-live value for the item in seconds.  Should be set
	// before calling [cacheItem.pack].  Items returned by [cache.unpackItem]
	// leave it zero.
	ttl uint32
}

// respToItem builds a cache item from the response m and the upstream u that
// resolved it.  It returns nil if m isn't cacheable, i.e. its cache TTL is
// zero.  u may be nil, in which case the item's upstream address is empty.  l
// must not be nil.
func (c *cache) respToItem(m *dns.Msg, u upstream.Upstream, l *slog.Logger) (item *cacheItem) {
	ttl := cacheTTL(m, l)
	if ttl == 0 {
		return nil
	}

	item = &cacheItem{
		m:   m,
		ttl: ttl,
	}

	if u != nil {
		item.u = u.Address()
	}

	return item
}

const (
	// packedMsgLenSz is the exact length of byte slice capable to store the
	// length of packed DNS message.  It's essentially the size of a uint16.
	// Cache keys also use it as the size of the QTYPE and QCLASS fields.
	packedMsgLenSz = 2
	// expTimeSz is the exact length of byte slice capable to store the
	// expiration time of the response.  It's essentially the size of a
	// uint32.
	expTimeSz = 4

	// minPackedLen is the minimum length of the packed cacheItem: the
	// expiration time followed by the packed message length.
	minPackedLen = expTimeSz + packedMsgLenSz
)

// pack converts ci into a byte slice with the following layout: a big-endian
// uint32 expiration time in Unix seconds, a big-endian uint16 length of the
// packed DNS message, the packed message itself, and the upstream address.
//
// The expiration time saturates at math.MaxUint32 instead of wrapping around,
// so that a huge TTL never produces an expiration time in the past.
func (ci *cacheItem) pack() (packed []byte) {
	// Packing errors are deliberately ignored: a message that fails to pack
	// is stored with a zero length, which unpackItem treats as invalid.
	pm, _ := ci.m.Pack()
	pmLen := len(pm)
	packed = make([]byte, minPackedLen, minPackedLen+pmLen+len(ci.u))

	// Put expiration time.  Compute it in 64 bits, since adding a large TTL to
	// the current time overflows uint32.
	expire := uint64(time.Now().Unix()) + uint64(ci.ttl)
	binary.BigEndian.PutUint32(packed, uint32(min(expire, math.MaxUint32)))

	// Put the length of the packed message.
	binary.BigEndian.PutUint16(packed[expTimeSz:], uint16(pmLen))

	// Put the packed message itself.
	packed = append(packed, pm...)

	// Put the address of the upstream.
	packed = append(packed, ci.u...)

	return packed
}

// optimisticTTL is the default TTL for expired cached responses in seconds.
// unpackItem sets it on the records of expired items returned by an
// optimistic cache.
const optimisticTTL = 10

// unpackItem converts the data into cacheItem using req as a request message.
// expired is true if the item exists but expired.  The expired cached items are
// only returned if c is optimistic.  req must not be nil.
//
// data is expected to have the layout produced by [cacheItem.pack].  A nil ci
// is returned for data that is too short, has a zero message length, or fails
// to unpack.  The TTLs of the returned records are set to the remaining time
// until expiration, or to optimisticTTL for expired items.  The ttl field of
// the returned item is left zero.
func (c *cache) unpackItem(data []byte, req *dns.Msg) (ci *cacheItem, expired bool) {
	if len(data) < minPackedLen {
		return nil, false
	}

	b := bytes.NewBuffer(data)
	expire := int64(binary.BigEndian.Uint32(b.Next(expTimeSz)))
	now := time.Now().Unix()
	var ttl uint32
	if expired = expire <= now; expired {
		if !c.optimistic {
			return nil, expired
		}

		ttl = optimisticTTL
	} else {
		ttl = uint32(expire - now)
	}

	l := int(binary.BigEndian.Uint16(b.Next(packedMsgLenSz)))
	if l == 0 {
		return nil, expired
	}

	m := &dns.Msg{}
	if m.Unpack(b.Next(l)) != nil {
		return nil, expired
	}

	// Build the response from req so that it has the request's ID and
	// question, copying only the relevant header bits from the cached message.
	res := (&dns.Msg{}).SetRcode(req, m.Rcode)
	res.AuthenticatedData = m.AuthenticatedData
	res.RecursionAvailable = m.RecursionAvailable

	var doBit bool
	if o := req.IsEdns0(); o != nil {
		doBit = o.Do()
	}

	// Don't return OPT records from cache, since RFC 6891 forbids caching
	// them.  If the request has the DO bit set, only the OPT RRs are removed;
	// otherwise, the DNSSEC RRs are removed as well.
	filterMsg(res, m, req.AuthenticatedData, doBit, ttl)

	// Everything after the packed message is the upstream address.
	return &cacheItem{
		m: res,
		u: string(b.Next(b.Len())),
	}, expired
}

// initCache sets up the response cache and the optimistic resolver of p when
// caching is enabled.  Otherwise, it only logs that the cache is disabled.
func (p *Proxy) initCache() {
	if p.CacheEnabled {
		size := p.CacheSizeBytes
		p.logger.Info("cache enabled", "size", size)

		p.cache = newCache(size, p.EnableEDNSClientSubnet, p.CacheOptimistic)
		p.shortFlighter = newOptimisticResolver(p)

		return
	}

	p.logger.Info("cache disabled")
}

// newCache returns a properly initialized cache with stores of the given size
// in bytes.  The ECS-aware store is only created when withECS is true.
// optimistic makes the cache return expired items.
func newCache(size int, withECS, optimistic bool) (c *cache) {
	items := createCache(size)

	var itemsWithSubnet glcache.Cache
	if withECS {
		itemsWithSubnet = createCache(size)
	}

	return &cache{
		itemsLock:           &sync.RWMutex{},
		itemsWithSubnetLock: &sync.RWMutex{},
		items:               items,
		itemsWithSubnet:     itemsWithSubnet,
		optimistic:          optimistic,
	}
}

// get looks up the cached item for req.  expired reports whether the found
// item's TTL has passed.  key is the cache key computed for req, returned so
// that callers don't need to recalculate it; it's nil if req can't be looked
// up at all.  Unusable entries are removed from the cache.
func (c *cache) get(req *dns.Msg) (ci *cacheItem, expired bool, key []byte) {
	c.itemsLock.RLock()
	defer c.itemsLock.RUnlock()

	if !canLookUpInCache(c.items, req) {
		return nil, false, nil
	}

	key = msgToKey(req)

	data := c.items.Get(key)
	if data == nil {
		return nil, false, key
	}

	ci, expired = c.unpackItem(data, req)
	if ci == nil {
		// The entry is either malformed or expired in a non-optimistic cache.
		c.items.Del(key)
	}

	return ci, expired, key
}

// getWithSubnet returns cached item for the req if it's found by n.  expired
// is true if the item's TTL is expired.  k is the resulting key for req.  It's
// returned to avoid recalculating it afterwards.  If nothing is found, k is the
// key for the shortest prefix tried.
//
// Note that a slow longest-prefix-match algorithm is used, so cache searches
// are performed up to mask+1 times.
func (c *cache) getWithSubnet(req *dns.Msg, n *net.IPNet) (ci *cacheItem, expired bool, k []byte) {
	c.itemsWithSubnetLock.RLock()
	defer c.itemsWithSubnetLock.RUnlock()

	if !canLookUpInCache(c.itemsWithSubnet, req) {
		return nil, false, nil
	}

	ecsIP := n.IP.Mask(n.Mask)
	ipLen := len(ecsIP)
	m, _ := n.Mask.Size()

	k = msgToKeyWithSubnet(req, ecsIP, m)
	data := c.itemsWithSubnet.Get(k)

	// Walk the prefix length down from m-1 to 0, modifying the key in place
	// to avoid allocations.  The key for prefix length pl+1 already has every
	// IP bit at position pl+1 and beyond cleared, so turning it into the key
	// for pl only requires keeping the pl%8 most significant bits of the byte
	// that contains bit pl.  When pl%8 is zero, the mask is zero and the whole
	// byte is cleared.  Starting from m-1 also avoids looking up the initial
	// key twice, and a key built with a zero mask, which has no IP address in
	// it, is never modified.
	for pl := m - 1; pl >= 0 && data == nil; pl-- {
		k[keyMaskIndex] = byte(pl)

		if pl == 0 {
			// The key for a zero-length prefix has no IP address in it.
			k = slices.Delete(k, keyIPIndex, keyIPIndex+ipLen)
		} else {
			k[keyIPIndex+pl/8] &= ^byte(0) << (8 - pl%8)
		}

		data = c.itemsWithSubnet.Get(k)
	}

	if data == nil {
		return nil, false, k
	}

	if ci, expired = c.unpackItem(data, req); ci == nil {
		c.itemsWithSubnet.Del(k)
	}

	return ci, expired, k
}

// canLookUpInCache reports whether a lookup of req in cache makes sense: both
// must be non-nil and req must have exactly one question.
func canLookUpInCache(cache glcache.Cache, req *dns.Msg) (ok bool) {
	if cache == nil || req == nil {
		return false
	}

	return len(req.Question) == 1
}

// createCache returns a new LRU cache limited to cacheSize bytes, or to
// defaultCacheSize bytes if cacheSize isn't positive.
func createCache(cacheSize int) (glc glcache.Cache) {
	maxSize := uint(defaultCacheSize)
	if cacheSize > 0 {
		maxSize = uint(cacheSize)
	}

	return glcache.New(glcache.Config{
		MaxSize:   maxSize,
		EnableLRU: true,
	})
}

// set stores the response m together with the upstream u that resolved it,
// unless m isn't cacheable.  l must not be nil.
func (c *cache) set(m *dns.Msg, u upstream.Upstream, l *slog.Logger) {
	item := c.respToItem(m, u, l)
	if item == nil {
		return
	}

	// Compute the key and the value before locking to keep the critical
	// section short.
	key, packed := msgToKey(m), item.pack()

	c.itemsLock.Lock()
	defer c.itemsLock.Unlock()

	c.items.Set(key, packed)
}

// setWithSubnet stores response and upstream with subnet in the cache.  The
// given subnet mask and IP address are used to calculate the cache key.  It
// does nothing if the cache was created without ECS support or m isn't
// cacheable.  l must not be nil.
func (c *cache) setWithSubnet(m *dns.Msg, u upstream.Upstream, subnet *net.IPNet, l *slog.Logger) {
	if c.itemsWithSubnet == nil {
		// ECS disabled, there is nowhere to store the item.  This mirrors the
		// checks in getWithSubnet and clearItemsWithSubnet.
		return
	}

	item := c.respToItem(m, u, l)
	if item == nil {
		return
	}

	pref, _ := subnet.Mask.Size()
	key := msgToKeyWithSubnet(m, subnet.IP.Mask(subnet.Mask), pref)
	packed := item.pack()

	c.itemsWithSubnetLock.Lock()
	defer c.itemsWithSubnetLock.Unlock()

	c.itemsWithSubnet.Set(key, packed)
}

// clearItems removes every entry from the cache that isn't keyed by subnet.
func (c *cache) clearItems() {
	lock := c.itemsLock
	lock.Lock()
	defer lock.Unlock()

	c.items.Clear()
}

// clearItemsWithSubnet removes every entry from the subnet-keyed cache.  It's
// a no-op when the cache was created without ECS support.
func (c *cache) clearItemsWithSubnet() {
	if c.itemsWithSubnet != nil {
		c.itemsWithSubnetLock.Lock()
		defer c.itemsWithSubnetLock.Unlock()

		c.itemsWithSubnet.Clear()
	}
}

// cacheTTL returns the number of seconds for which m is valid to be cached.
// For negative answers it follows RFC 2308 on how to cache NXDOMAIN and NODATA
// kinds of responses.  SERVFAIL responses are cached with the TTL returned by
// calculateTTL, which caps it at ServFailMaxCacheTTL.  Responses with any
// other response code aren't cached.  l must not be nil.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-2.1,
// https://datatracker.ietf.org/doc/html/rfc2308#section-2.2.
func cacheTTL(m *dns.Msg, l *slog.Logger) (ttl uint32) {
	switch {
	case m == nil:
		return 0
	case m.Truncated:
		l.Debug("truncated message; not caching")

		return 0
	case len(m.Question) != 1:
		l.Debug("message with wrong number of questions; not caching")

		return 0
	default:
		ttl = calculateTTL(m)
		if ttl == 0 {
			l.Debug("ttl calculated to be 0; not caching")

			return 0
		}
	}

	switch rcode := m.Rcode; rcode {
	case dns.RcodeSuccess:
		if isCacheableSucceded(m) {
			return ttl
		}

		l.Debug("not a cacheable noerror response; not caching")
	case dns.RcodeNameError:
		if isCacheableNegative(m) {
			return ttl
		}

		l.Debug("not a cacheable nxdomain response; not caching")
	case dns.RcodeServerFailure:
		return ttl
	default:
		// log/slog doesn't interpret format verbs in the message, so the
		// response code is only passed as an attribute.
		l.Debug("unsupported response code; not caching", "rcode", dns.RcodeToString[rcode])
	}

	return 0
}

// hasIPAns reports whether the Answer section of m contains at least one A or
// AAAA record.
func hasIPAns(m *dns.Msg) (ok bool) {
	for _, rr := range m.Answer {
		switch rr.Header().Rrtype {
		case dns.TypeA, dns.TypeAAAA:
			return true
		}
	}

	return false
}

// isCacheableSucceded reports whether m, treated as a successful response,
// carries data worth caching.  Responses to questions other than A and AAAA
// always do; A and AAAA responses need an IP answer or must be cacheable
// negative responses.  m must have at least one question.
func isCacheableSucceded(m *dns.Msg) (ok bool) {
	switch m.Question[0].Qtype {
	case dns.TypeA, dns.TypeAAAA:
		return hasIPAns(m) || isCacheableNegative(m)
	default:
		return true
	}
}

// isCacheableNegative reports whether the authority section of m has at least
// one SOA RR and no NS RRs, so that the response can be declared
// authoritative.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-5 for the
// information on the responses from the authoritative server that should be
// cached by the forwarder.
func isCacheableNegative(m *dns.Msg) (ok bool) {
	hasSOA := false
	for _, rr := range m.Ns {
		rrType := rr.Header().Rrtype
		if rrType == dns.TypeNS {
			return false
		}

		hasSOA = hasSOA || rrType == dns.TypeSOA
	}

	return hasSOA
}

// ServFailMaxCacheTTL is the maximum time-to-live value for caching
// SERVFAIL responses in seconds.  It's consistent with the upper constraint
// of 5 minutes given by RFC 2308.  calculateTTL caps the TTL of SERVFAIL
// responses to this value.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-7.1.
const ServFailMaxCacheTTL = 30

// calculateTTL returns the number of seconds for which m could be cached.  It's
// usually the lowest TTL among all the resource records of m, ignoring OPT
// RRs.  SERVFAIL responses are capped at ServFailMaxCacheTTL, even if they have
// no records.  It returns 0 if m isn't cacheable according to its contents.
func calculateTTL(m *dns.Msg) (ttl uint32) {
	// Start from the maximum value as a guard: any lower TTL found below
	// replaces it, and if it's still there afterwards, no usable TTL was found.
	ttl = math.MaxUint32

	sections := [...][]dns.RR{m.Answer, m.Ns, m.Extra}
	for i := range sections {
		for _, rr := range sections[i] {
			if ttl = minTTL(rr.Header(), ttl); ttl == 0 {
				return 0
			}
		}
	}

	if m.Rcode == dns.RcodeServerFailure && ttl > ServFailMaxCacheTTL {
		return ServFailMaxCacheTTL
	}

	if ttl == math.MaxUint32 {
		return 0
	}

	return ttl
}

// minTTL returns the smaller of the TTL in h and ttl.  OPT records are
// ignored, since their TTL field doesn't hold a real time-to-live.
func minTTL(h *dns.RR_Header, ttl uint32) uint32 {
	if h.Rrtype != dns.TypeOPT && h.Ttl < ttl {
		return h.Ttl
	}

	return ttl
}

// respectTTLOverrides returns ttl adjusted to fall within the range specified
// by the cacheMinTTL and cacheMaxTTL settings.  A cacheMaxTTL of zero means no
// upper bound.  The lower bound is checked first.
func respectTTLOverrides(ttl, cacheMinTTL, cacheMaxTTL uint32) uint32 {
	switch {
	case ttl < cacheMinTTL:
		return cacheMinTTL
	case cacheMaxTTL != 0 && ttl > cacheMaxTTL:
		return cacheMaxTTL
	default:
		return ttl
	}
}

// msgToKey constructs the cache key from the first question of m.  The key is
// laid out as the big-endian QTYPE, the big-endian QCLASS, and then the
// lowercased QNAME.  m must have at least one question.
func msgToKey(m *dns.Msg) (b []byte) {
	q := m.Question[0]

	const nameOff = 2 * packedMsgLenSz

	b = make([]byte, nameOff+len(q.Name))
	binary.BigEndian.PutUint16(b[:packedMsgLenSz], q.Qtype)
	binary.BigEndian.PutUint16(b[packedMsgLenSz:nameOff], q.Qclass)
	copy(b[nameOff:], strings.ToLower(q.Name))

	return b
}

const (
	// keyMaskIndex is the index of the byte with mask ones value.  It follows
	// one byte reserved for the DO bit and the 2-byte QTYPE and QCLASS
	// fields, see msgToKeyWithSubnet.
	keyMaskIndex = 1 + 2*packedMsgLenSz

	// keyIPIndex is the start index of the IP address in the key.
	keyIPIndex = keyMaskIndex + 1
)

// msgToKeyWithSubnet constructs the cache key from DO bit, type, class, subnet
// mask, client's IP address and question's name of m.  ecsIP is expected to be
// masked already.  The IP address is only put into the key if mask isn't zero.
//
// The resulting key has the following layout:
//
//	[0:2]  QTYPE, big-endian
//	[2]    always zero
//	[3:5]  QCLASS, big-endian
//	[5]    mask, see keyMaskIndex
//	[6:]   ecsIP if mask isn't zero, then the lowercased QNAME
//
// NOTE(review): the DO bit is written to key[0] but then overwritten by the
// QTYPE, so it currently doesn't affect the key; see the TODO below.
func msgToKeyWithSubnet(m *dns.Msg, ecsIP net.IP, mask int) (key []byte) {
	q := m.Question[0]
	keyLen := keyIPIndex + len(q.Name)
	masked := mask != 0
	if masked {
		keyLen += len(ecsIP)
	}

	// Initialize the slice.
	key = make([]byte, keyLen)

	// Put DO.
	opt := m.IsEdns0()
	key[0] = mathutil.BoolToNumber[byte](opt != nil && opt.Do())

	// Put Qtype.
	//
	// TODO(d.kolyshev): We should put Qtype in key[1:].
	binary.BigEndian.PutUint16(key[:], q.Qtype)

	// Put Qclass.
	binary.BigEndian.PutUint16(key[1+packedMsgLenSz:], q.Qclass)

	// Add mask.
	key[keyMaskIndex] = uint8(mask)
	k := keyIPIndex
	if masked {
		k += copy(key[keyIPIndex:], ecsIP)
	}

	copy(key[k:], strings.ToLower(q.Name))

	return key
}

// isDNSSEC reports whether r is one of the DNSSEC record types: NSEC, NSEC3,
// DS, RRSIG, SIG, or DNSKEY.
func isDNSSEC(r dns.RR) bool {
	rrType := r.Header().Rrtype

	return rrType == dns.TypeNSEC ||
		rrType == dns.TypeNSEC3 ||
		rrType == dns.TypeDS ||
		rrType == dns.TypeRRSIG ||
		rrType == dns.TypeSIG ||
		rrType == dns.TypeDNSKEY
}

// filterRRSlice returns copies of rrs without OPT RRs and, if do is false,
// without DNSSEC RRs other than those of type except.  If ttl isn't zero, the
// TTL of every returned copy is set to it.  rrs itself is never modified.  It
// returns nil for an empty rrs.
func filterRRSlice(rrs []dns.RR, do bool, ttl uint32, except uint16) (filtered []dns.RR) {
	if len(rrs) == 0 {
		return nil
	}

	filtered = make([]dns.RR, 0, len(rrs))
	for _, r := range rrs {
		rrType := r.Header().Rrtype
		if rrType == dns.TypeOPT || (!do && rrType != except && isDNSSEC(r)) {
			continue
		}

		// Copy before changing the TTL so that the caller's records stay
		// intact.
		cp := dns.Copy(r)
		if ttl != 0 {
			cp.Header().Ttl = ttl
		}

		filtered = append(filtered, cp)
	}

	return filtered
}

// filterMsg fills the Answer, Ns, and Extra sections of dst with filtered
// copies of the ones in m: OPT RRs are dropped, DNSSEC RRs are dropped if do
// is false (except the ones of the question type in the Answer section), and
// TTLs are set to ttl if it isn't zero.  It also clears the AD bit of dst if
// both ad and do are false.  m must have at least one question.
func filterMsg(dst, m *dns.Msg, ad, do bool, ttl uint32) {
	// As RFC 6840 says, validating resolvers should only set the AD bit when a
	// response both meets the conditions listed in RFC 4035, and the request
	// contained either a set DO bit or a set AD bit.
	if !ad && !do {
		dst.AuthenticatedData = false
	}

	// Only the DNSSEC RRs that weren't explicitly requested are filtered out.
	//
	// See https://datatracker.ietf.org/doc/html/rfc4035#section-3.2.1 and
	// https://github.com/AdguardTeam/dnsproxy/issues/144.
	qtype := m.Question[0].Qtype
	dst.Answer = filterRRSlice(m.Answer, do, ttl, qtype)
	dst.Ns = filterRRSlice(m.Ns, do, ttl, dns.TypeNone)
	dst.Extra = filterRRSlice(m.Extra, do, ttl, dns.TypeNone)
}
07070100000056000081A4000000000000000000000001679A649F00005EB9000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/cache_test.gopackage proxy

import (
	"context"
	"net"
	"net/netip"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/miekg/dns"
)

const (
	// testCacheSize is the maximum size of cache for tests.
	testCacheSize = 4096

	// testUpsAddr is the address reported by upstreamWithAddr.
	testUpsAddr = "https://upstream.address"
)

// upstreamWithAddr is a fake upstream for the cache tests that only reports
// its address.  Its other methods panic and must not be called.
var upstreamWithAddr = &fakeUpstream{
	onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) { panic("not implemented") },
	onClose:    func() (err error) { panic("not implemented") },
	onAddress:  func() (addr string) { return testUpsAddr },
}

// TestServeCached checks that a response put directly into the proxy's cache
// is served to a UDP client without contacting the upstream.
func TestServeCached(t *testing.T) {
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
	})

	ctx := context.Background()
	require.NoError(t, dnsProxy.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	const host = "google.com."

	// Put the response into the cache, bypassing the upstream.
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{Response: true},
		Answer: []dns.RR{newRR(t, host, dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
	}).SetQuestion(host, dns.TypeA)
	reply.SetEdns0(defaultUDPBufSize, false)
	dnsProxy.cache.set(reply, upstreamWithAddr, slogutil.NewDiscardLogger())

	// Ask the proxy over UDP and expect exactly the cached response back.
	req := (&dns.Msg{}).SetQuestion(host, dns.TypeA)
	req.SetEdns0(defaultUDPBufSize, false)

	client := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}

	resp, _, err := client.Exchange(req, dnsProxy.Addr(ProtoUDP).String())
	require.NoErrorf(t, err, "error in the first request: %s", err)

	requireEqualMsgs(t, resp, reply)
}

// TestCache_expired checks that expired items are only returned by an
// optimistic cache, and that they're returned with the optimisticTTL.
func TestCache_expired(t *testing.T) {
	const host = "google.com."

	// ans is shared by all the test cases; its TTL is changed for each one.
	ans := &dns.A{
		Hdr: dns.RR_Header{
			Name:   host,
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
		},
		A: net.IP{8, 8, 8, 8},
	}
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{
			Response: true,
		},
		Answer: []dns.RR{ans},
	}).SetQuestion(host, dns.TypeA)

	// A zero ttl makes pack store the current time as the expiration time, so
	// such items are already expired when read back.
	testCases := []struct {
		name       string
		ttl        uint32
		wantTTL    uint32
		optimistic bool
	}{{
		name:       "realistic_hit",
		ttl:        defaultTestTTL,
		wantTTL:    defaultTestTTL,
		optimistic: false,
	}, {
		name:       "realistic_miss",
		ttl:        0,
		wantTTL:    0,
		optimistic: false,
	}, {
		name:       "optimistic_hit",
		ttl:        defaultTestTTL,
		wantTTL:    defaultTestTTL,
		optimistic: true,
	}, {
		name:       "optimistic_expired",
		ttl:        0,
		wantTTL:    optimisticTTL,
		optimistic: true,
	}}

	testCache := newCache(testCacheSize, false, false)
	for _, tc := range testCases {
		ans.Hdr.Ttl = tc.ttl
		req := (&dns.Msg{}).SetQuestion(host, dns.TypeA)

		t.Run(tc.name, func(t *testing.T) {
			if tc.optimistic {
				testCache.optimistic = true
				t.Cleanup(func() { testCache.optimistic = false })
			}

			// Store the item directly to control its TTL precisely.
			key := msgToKey(reply)
			data := (&cacheItem{
				m:   reply,
				u:   testUpsAddr,
				ttl: tc.ttl,
			}).pack()
			testCache.items.Set(key, data)
			t.Cleanup(testCache.items.Clear)

			r, expired, key := testCache.get(req)
			assert.Equal(t, msgToKey(req), key)
			assert.Equal(t, tc.ttl == 0, expired)

			if tc.wantTTL != 0 {
				require.NotNil(t, r)

				assert.Equal(t, tc.wantTTL, r.m.Answer[0].Header().Ttl)
				assert.Equal(t, testUpsAddr, r.u)
			} else {
				require.Nil(t, r)
			}
		})
	}
}

// TestCacheDO checks that a response cached with the DO bit set is returned
// both for requests without the DO bit and for requests with it.
func TestCacheDO(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	// Fill the cache.
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{
			Response: true,
		},
		Answer: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
	}).SetQuestion("google.com.", dns.TypeA)
	reply.SetEdns0(4096, true)

	// Store in cache.
	testCache.set(reply, upstreamWithAddr, slogutil.NewDiscardLogger())

	// Make a request.
	request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)

	t.Run("without_do", func(t *testing.T) {
		ci, expired, key := testCache.get(request)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(request), key)
		assert.NotNil(t, ci)
	})

	t.Run("with_do", func(t *testing.T) {
		// The request is shared with the other subtest, so restore its
		// original state afterwards.
		reqClone := request.Copy()
		t.Cleanup(func() {
			request = reqClone
		})

		request.SetEdns0(4096, true)

		ci, expired, key := testCache.get(request)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(request), key)

		require.NotNil(t, ci)

		assert.Equal(t, testUpsAddr, ci.u)
	})
}

// TestCacheCNAME checks that an A response containing only a CNAME isn't
// cached, while the one that also contains the resolved A record is.
func TestCacheCNAME(t *testing.T) {
	l := slogutil.NewDiscardLogger()

	testCache := newCache(testCacheSize, false, false)

	// Fill the cache
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{
			Response: true,
		},
		Answer: []dns.RR{newRR(t, "google.com.", dns.TypeCNAME, 3600, "test.google.com.")},
	}).SetQuestion("google.com.", dns.TypeA)
	testCache.set(reply, upstreamWithAddr, l)

	// Create a DNS request.
	request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)

	t.Run("no_cnames", func(t *testing.T) {
		r, expired, _ := testCache.get(request)
		assert.Nil(t, r)
		assert.False(t, expired)
	})

	// Now fill the cache with a cacheable CNAME response.
	reply.Answer = append(reply.Answer, newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8}))
	testCache.set(reply, upstreamWithAddr, l)

	// We are testing that a proper CNAME response gets cached
	t.Run("cnames_exist", func(t *testing.T) {
		r, expired, key := testCache.get(request)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(request), key)

		require.NotNil(t, r)

		assert.Equal(t, testUpsAddr, r.u)
	})
}

// TestCache_uncacheable checks that responses with unsupported response codes
// aren't cached even if their records have a non-zero TTL.
func TestCache_uncacheable(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	// Create a DNS request.
	request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)

	// Create a response with an unsupported response code.  It has an answer
	// with a non-zero TTL, so that only the response code prevents caching.
	reply := (&dns.Msg{}).SetRcode(request, dns.RcodeBadAlg)
	reply.Answer = []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})}

	testCache.set(reply, upstreamWithAddr, slogutil.NewDiscardLogger())

	r, expired, _ := testCache.get(request)
	assert.Nil(t, r)
	assert.False(t, expired)
}

// TestCache_concurrent checks that the cache can be filled and read from
// several goroutines at once.
func TestCache_concurrent(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	hosts := map[string]string{
		dns.Fqdn("yandex.com"):     "213.180.204.62",
		dns.Fqdn("google.com"):     "8.8.8.8",
		dns.Fqdn("www.google.com"): "8.8.4.4",
		dns.Fqdn("youtube.com"):    "173.194.221.198",
		dns.Fqdn("car.ru"):         "37.220.161.35",
		dns.Fqdn("cat.ru"):         "192.56.231.67",
	}

	wg := &sync.WaitGroup{}
	for host, ip := range hosts {
		wg.Add(1)
		go setAndGetCache(t, testCache, wg, host, ip)
	}

	wg.Wait()
}

const (
	// cacheTick is the period of checking whether cached items have expired.
	cacheTick = 100 * time.Millisecond

	// cacheTimeout is the maximum time to wait for cached items to expire,
	// i.e. 2 seconds.
	cacheTimeout = 20 * cacheTick
)

// TestCacheExpiration checks that responses with a TTL of 1 are returned from
// the cache of a running proxy right after being cached and stop being
// returned within cacheTimeout.
func TestCacheExpiration(t *testing.T) {
	t.Parallel()

	prx := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
	})

	ctx := context.Background()
	require.NoError(t, prx.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

	logger := slogutil.NewDiscardLogger()

	// Every record has a TTL of 1, so that it expires quickly.
	records := []dns.RR{
		newRR(t, "youtube.com.", dns.TypeA, 1, net.IP{173, 194, 221, 198}),
		newRR(t, "google.com.", dns.TypeA, 1, net.IP{8, 8, 8, 8}),
		newRR(t, "yandex.com.", dns.TypeA, 1, net.IP{213, 180, 204, 62}),
	}

	replies := make([]*dns.Msg, 0, len(records))
	for _, rec := range records {
		reply := (&dns.Msg{
			MsgHdr: dns.MsgHdr{Response: true},
			Answer: []dns.RR{dns.Copy(rec)},
		}).SetQuestion(rec.Header().Name, dns.TypeA)

		prx.cache.set(reply, upstreamWithAddr, logger)
		replies = append(replies, reply)
	}

	// Right after being set, every reply must be found and not expired.
	for _, reply := range replies {
		item, isExpired, gotKey := prx.cache.get(reply)
		require.NotNil(t, item)

		assert.False(t, isExpired)
		assert.Equal(t, msgToKey(item.m), gotKey)

		requireEqualMsgs(t, item.m, reply)
	}

	allGone := func() (ok bool) {
		for _, reply := range replies {
			if item, _, _ := prx.cache.get(reply); item != nil {
				return false
			}
		}

		return true
	}
	assert.Eventually(t, allGone, cacheTimeout, cacheTick)
}

// TestCacheExpirationWithTTLOverride checks that the TTLs of cached responses
// are replaced with CacheMinTTL when they are below it and with CacheMaxTTL
// when they are above it.
func TestCacheExpirationWithTTLOverride(t *testing.T) {
	u := testUpstream{}

	dnsProxy := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{&u},
		},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
		CacheMinTTL:            20,
		CacheMaxTTL:            40,
	})

	ctx := context.Background()
	require.NoError(t, dnsProxy.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	t.Run("replace_min", func(t *testing.T) {
		// Use a new context in each subtest, so that no state left over from
		// a previous resolution affects this one.
		d := &DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.AddrPort{},
		}

		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host.",
				Ttl:    10,
			},
			A: net.IP{4, 3, 2, 1},
		}}

		require.NoError(t, dnsProxy.Resolve(d))

		ci, expired, key := dnsProxy.cache.get(d.Req)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(d.Req), key)

		require.NotNil(t, ci)
		require.NotEmpty(t, ci.m.Answer)

		// The upstream TTL of 10 is below CacheMinTTL.
		assert.Equal(t, dnsProxy.CacheMinTTL, ci.m.Answer[0].Header().Ttl)
	})

	t.Run("replace_max", func(t *testing.T) {
		d := &DNSContext{
			Req:  newHostTestMessage("host2"),
			Addr: netip.AddrPort{},
		}

		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host2.",
				Ttl:    60,
			},
			A: net.IP{4, 3, 2, 1},
		}}

		// Stop the subtest on failure, like replace_min does, since the
		// checks below make no sense without a resolved response.
		require.NoError(t, dnsProxy.Resolve(d))

		ci, expired, key := dnsProxy.cache.get(d.Req)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(d.Req), key)

		require.NotNil(t, ci)
		require.NotEmpty(t, ci.m.Answer)

		// The upstream TTL of 60 is above CacheMaxTTL.
		assert.Equal(t, dnsProxy.CacheMaxTTL, ci.m.Answer[0].Header().Ttl)
	})
}

type (
	// testEntry is a response put into the cache before the test cases of
	// [testCases.run] are checked.
	testEntry struct {
		// q is the question name of the response.
		q string

		// a is the answer section of the response.
		a []dns.RR

		// t is the question type of the response.
		t uint16
	}

	// testCase is a single cache lookup checked by [testCases.run].
	testCase struct {
		// ok asserts whether the lookup found a cached response.
		ok require.BoolAssertionFunc

		// q is the question name to look up.
		q string

		// a is the expected answer section.  nil means that only the presence
		// of a cached response is checked.
		a []dns.RR

		// t is the question type to look up.
		t uint16
	}

	// testCases is the cache contents and the lookups to check against them.
	testCases struct {
		cache []testEntry
		cases []testCase
	}
)

func TestCache(t *testing.T) {
	t.Run("simple", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "google.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.True,
				q:  "google.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})

	t.Run("mixed_case", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "gOOgle.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.True,
				q:  "gOOgle.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.True,
				q:  "google.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.True,
				q:  "GOOGLE.COM.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				q:  "gOOgle.com.",
				t:  dns.TypeMX,
				ok: require.False,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "GOOGLE.COM.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})

	t.Run("zero_ttl", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "gOOgle.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 0, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})
}

// run fills a new cache with the responses from tests.cache and then checks
// every lookup from tests.cases against it.  A case with a nil answer only
// checks whether a cached response is found; otherwise the found response is
// also compared with a response built from the case's answer.
func (tests testCases) run(t *testing.T) {
	l := slogutil.NewDiscardLogger()

	testCache := newCache(testCacheSize, false, false)

	for _, res := range tests.cache {
		reply := (&dns.Msg{
			MsgHdr: dns.MsgHdr{
				Response: true,
			},
			Answer: res.a,
		}).SetQuestion(res.q, res.t)
		testCache.set(reply, upstreamWithAddr, l)
	}

	for _, tc := range tests.cases {
		request := (&dns.Msg{}).SetQuestion(tc.q, tc.t)

		ci, expired, _ := testCache.get(request)
		assert.False(t, expired)
		tc.ok(t, ci != nil)

		// Go on to the next case instead of returning, so that the cases
		// after one without an answer are still checked.
		if tc.a == nil || ci == nil {
			continue
		}

		reply := (&dns.Msg{
			MsgHdr: dns.MsgHdr{
				Response: true,
			},
			Answer: tc.a,
		}).SetQuestion(tc.q, tc.t)

		testCache.set(reply, upstreamWithAddr, l)

		requireEqualMsgs(t, ci.m, reply)
	}
}

// requireEqualMsgs asserts the messages are equal except their ID, the
// Rdlength of their answers, the case of their questions, and the 16-byte form
// of IPv4 addresses in the A records of actual.  Neither expected nor actual is
// modified, since the normalization is performed on copies.
func requireEqualMsgs(t *testing.T, expected, actual *dns.Msg) {
	t.Helper()

	// cloneQuestions copies qs, keeping nil as nil so that nil and empty
	// sections are still distinguished by the comparison.
	cloneQuestions := func(qs []dns.Question) (cloned []dns.Question) {
		if qs == nil {
			return nil
		}

		return append(make([]dns.Question, 0, len(qs)), qs...)
	}

	// cloneRRs deep-copies rrs, keeping nil as nil.
	cloneRRs := func(rrs []dns.RR) (cloned []dns.RR) {
		if rrs == nil {
			return nil
		}

		cloned = make([]dns.RR, 0, len(rrs))
		for _, rr := range rrs {
			cloned = append(cloned, dns.Copy(rr))
		}

		return cloned
	}

	// Only the question and answer sections are normalized below, so only
	// they need deep copies.
	exp := *expected
	exp.Id = actual.Id
	exp.Question = cloneQuestions(expected.Question)
	exp.Answer = cloneRRs(expected.Answer)

	act := *actual
	act.Question = cloneQuestions(actual.Question)
	act.Answer = cloneRRs(actual.Answer)

	require.Equal(t, len(exp.Answer), len(act.Answer))
	for i, ans := range act.Answer {
		exp.Answer[i].Header().Rdlength = ans.Header().Rdlength

		if a, ok := ans.(*dns.A); ok {
			if a4 := a.A.To4(); a4 != nil {
				a.A = a4
			}
		}
	}

	for i := range exp.Question {
		exp.Question[i].Name = strings.ToLower(exp.Question[i].Name)
	}

	for i := range act.Question {
		act.Question[i].Name = strings.ToLower(act.Question[i].Name)
	}

	assert.Equal(t, &exp, &act)
}

// setAndGetCache stores a response with an A record of host pointing to ip
// with a TTL of 1 in c.  It then checks twice that the response is returned
// unexpired under the expected key, and finally asserts that c stops returning
// it within cacheTimeout.  It calls g.Done when it returns.
func setAndGetCache(t *testing.T, c *cache, g *sync.WaitGroup, host, ip string) {
	defer g.Done()

	msg := (&dns.Msg{
		MsgHdr: dns.MsgHdr{Response: true},
		Answer: []dns.RR{newRR(t, host, dns.TypeA, 1, net.ParseIP(ip))},
	}).SetQuestion(host, dns.TypeA)

	c.set(msg, upstreamWithAddr, slogutil.NewDiscardLogger())

	for range 2 {
		item, isExpired, gotKey := c.get(msg)
		require.NotNilf(t, item, "no cache found for %s", host)

		assert.False(t, isExpired)
		assert.Equal(t, msgToKey(msg), gotKey)

		requireEqualMsgs(t, item.m, msg)
	}

	isRemoved := func() (ok bool) {
		item, _, _ := c.get(msg)

		return item == nil
	}
	assert.Eventuallyf(t, isRemoved, cacheTimeout, cacheTick, "cache for %s should already be removed", host)
}

func TestCache_getWithSubnet(t *testing.T) {
	const testFQDN = "example.com."

	ip1234, ip2234, ip3234 := net.IP{1, 2, 3, 4}, net.IP{2, 2, 3, 4}, net.IP{3, 2, 3, 4}
	req := (&dns.Msg{}).SetQuestion(testFQDN, dns.TypeA)
	mask16 := net.CIDRMask(16, netutil.IPv4BitLen)
	mask24 := net.CIDRMask(24, netutil.IPv4BitLen)
	l := slogutil.NewDiscardLogger()

	c := newCache(testCacheSize, true, false)

	t.Run("empty", func(t *testing.T) {
		ci, expired, _ := c.getWithSubnet(req, &net.IPNet{IP: ip1234, Mask: mask24})
		assert.Nil(t, ci)
		assert.False(t, expired)
	})

	// Add a response with subnet.
	resp := (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 1, net.IP{1, 1, 1, 1})},
	}).SetReply(req)
	c.setWithSubnet(resp, upstreamWithAddr, &net.IPNet{IP: ip1234, Mask: mask16}, slogutil.NewDiscardLogger())

	t.Run("different_ip", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip2234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip2234, 0), key)
		assert.Nil(t, ci)
	})

	// Add a response entry with subnet #2.
	resp = (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 1, net.IP{2, 2, 2, 2})},
	}).SetReply(req)
	c.setWithSubnet(resp, upstreamWithAddr, &net.IPNet{IP: ip2234, Mask: mask16}, l)

	// Add a response entry without subnet.
	resp = (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 1, net.IP{3, 3, 3, 3})},
	}).SetReply(req)
	c.setWithSubnet(resp, upstreamWithAddr, &net.IPNet{IP: nil, Mask: nil}, l)

	t.Run("with_subnet_1", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip1234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip1234.Mask(mask16), 16), key)

		require.NotNil(t, ci)
		require.NotNil(t, ci.m)
		require.NotEmpty(t, ci.m.Answer)

		a := testutil.RequireTypeAssert[*dns.A](t, ci.m.Answer[0])
		assert.True(t, a.A.Equal(net.IP{1, 1, 1, 1}))
	})

	t.Run("with_subnet_2", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip2234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip2234.Mask(mask16), 16), key)

		require.NotNil(t, ci)
		require.NotNil(t, ci.m)
		require.NotEmpty(t, ci.m.Answer)

		a := testutil.RequireTypeAssert[*dns.A](t, ci.m.Answer[0])
		assert.True(t, a.A.Equal(net.IP{2, 2, 2, 2}))
	})

	t.Run("with_subnet_3", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip3234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip1234, 0), key)

		require.NotNil(t, ci)
		require.NotNil(t, ci.m)
		require.NotEmpty(t, ci.m.Answer)

		a := testutil.RequireTypeAssert[*dns.A](t, ci.m.Answer[0])
		assert.True(t, a.A.Equal(net.IP{3, 3, 3, 3}))
	})
}

// TestCache_getWithSubnet_mask checks that a response cached for a /20 network
// is found for a client /24 subnet within that network and isn't found for one
// outside of it.
func TestCache_getWithSubnet_mask(t *testing.T) {
	const (
		testFQDN = "example.com."

		// cidrMaskOnes is the prefix length of the cached network.
		cidrMaskOnes = 20
	)

	var (
		testIP    = net.IP{176, 112, 191, 0}
		noMatchIP = net.IP{177, 112, 191, 0}
		ansIP     = net.IP{4, 4, 4, 4}

		// cachedIP/cidrMask is the network that contains testIP, but not
		// noMatchIP.
		cachedIP = net.IP{176, 112, 176, 0}
		cidrMask = net.CIDRMask(cidrMaskOnes, netutil.IPv4BitLen)
	)

	c := newCache(testCacheSize, true, true)

	req := (&dns.Msg{}).SetQuestion(testFQDN, dns.TypeA)
	resp := (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 300, ansIP)},
	}).SetReply(req)

	cachedNet := &net.IPNet{IP: cachedIP, Mask: cidrMask}
	c.setWithSubnet(resp, upstreamWithAddr, cachedNet, slogutil.NewDiscardLogger())

	t.Run("mask_matched", func(t *testing.T) {
		clientNet := &net.IPNet{
			IP:   testIP,
			Mask: net.CIDRMask(24, netutil.IPv4BitLen),
		}

		item, isExpired, gotKey := c.getWithSubnet(req, clientNet)
		assert.False(t, isExpired)
		assert.Equal(t, msgToKeyWithSubnet(req, testIP.Mask(cidrMask), cidrMaskOnes), gotKey)

		require.NotNil(t, item)
		require.NotNil(t, item.m)
		require.NotEmpty(t, item.m.Answer)

		rec := testutil.RequireTypeAssert[*dns.A](t, item.m.Answer[0])
		assert.True(t, rec.A.Equal(ansIP))
	})

	t.Run("no_mask_matched", func(t *testing.T) {
		clientNet := &net.IPNet{
			IP:   noMatchIP,
			Mask: net.CIDRMask(24, netutil.IPv4BitLen),
		}

		item, isExpired, gotKey := c.getWithSubnet(req, clientNet)
		assert.False(t, isExpired)
		assert.Equal(t, msgToKeyWithSubnet(req, noMatchIP, 0), gotKey)
		assert.Nil(t, item)
	})
}

// TestCache_IsCacheable_negative checks the TTLs cacheTTL returns for the
// negative responses from RFC 2308 and for a SERVFAIL response.
func TestCache_IsCacheable_negative(t *testing.T) {
	const someTTL = 3600

	// The names below are taken from the examples of RFC 2308, including the
	// spelling of cname.
	const (
		hostname        = "AN.EXAMPLE."
		anotherHostname = "ANOTHER.EXAMPLE."
		cname           = "TRIPPLE.XX."
		mbox            = "HOSTMASTER.NS1.XX."
		ns1, ns2        = "NS1.XX.", "NS2.XX."
		xx              = "XX."
	)

	// rrHdr returns the header of an IN record of type rrType for name with
	// a TTL of someTTL.
	rrHdr := func(name string, rrType uint16) (hdr dns.RR_Header) {
		return dns.RR_Header{
			Name:   name,
			Rrtype: rrType,
			Class:  dns.ClassINET,
			Ttl:    someTTL,
		}
	}

	// msgHdr returns a message header with a random ID and the given rcode.
	msgHdr := func(rcode int) (hdr dns.MsgHdr) {
		return dns.MsgHdr{Id: dns.Id(), Rcode: rcode}
	}

	// question returns a question section asking for the IN A records of name.
	question := func(name string) (qs []dns.Question) {
		return []dns.Question{{
			Name:   name,
			Qtype:  dns.TypeA,
			Qclass: dns.ClassINET,
		}}
	}

	// cnameRRs returns an answer section with the CNAME record of hostname.
	cnameRRs := func() (rrs []dns.RR) {
		return []dns.RR{&dns.CNAME{Hdr: rrHdr(hostname, dns.TypeCNAME), Target: cname}}
	}

	// soaRR returns the SOA record of the XX. zone.
	soaRR := func() (rr dns.RR) {
		return &dns.SOA{Hdr: rrHdr(xx, dns.TypeSOA), Ns: ns1, Mbox: mbox}
	}

	// nsRRs returns the NS records of the XX. zone.
	nsRRs := func() (rrs []dns.RR) {
		return []dns.RR{
			&dns.NS{Hdr: rrHdr(xx, dns.TypeNS), Ns: ns1},
			&dns.NS{Hdr: rrHdr(xx, dns.TypeNS), Ns: ns2},
		}
	}

	// glueRRs returns the A records of the name servers of the XX. zone.
	glueRRs := func() (rrs []dns.RR) {
		return []dns.RR{
			&dns.A{Hdr: rrHdr(ns1, dns.TypeA), A: net.IP{127, 0, 0, 2}},
			&dns.A{Hdr: rrHdr(ns2, dns.TypeA), A: net.IP{127, 0, 0, 3}},
		}
	}

	// See https://datatracker.ietf.org/doc/html/rfc2308.
	testCases := []struct {
		resp    *dns.Msg
		name    string
		wantTTL uint32
	}{{
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: question(hostname),
			Answer:   cnameRRs(),
			Ns:       append([]dns.RR{soaRR()}, nsRRs()...),
			Extra:    glueRRs(),
		},
		name:    "rfc2308_nxdomain_response_type_1",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: question(hostname),
			Answer:   cnameRRs(),
			Ns:       []dns.RR{soaRR()},
		},
		name:    "rfc2308_nxdomain_response_type_2",
		wantTTL: someTTL,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: question(hostname),
			Answer:   cnameRRs(),
		},
		name:    "rfc2308_nxdomain_response_type_3",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: question(hostname),
			Answer:   cnameRRs(),
			Ns:       nsRRs(),
			Extra:    glueRRs(),
		},
		name:    "rfc2308_nxdomain_response_type_4",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: question(hostname),
			Answer:   cnameRRs(),
			Ns:       nsRRs(),
			Extra:    glueRRs(),
		},
		name:    "rfc2308_nxdomain_referral_response",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: question(anotherHostname),
			Ns:       append([]dns.RR{soaRR()}, nsRRs()...),
			Extra:    glueRRs(),
		},
		name:    "rfc2308_nodata_response_type_1",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: question(anotherHostname),
			Ns:       []dns.RR{soaRR()},
		},
		name:    "rfc2308_nodata_response_type_2",
		wantTTL: someTTL,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: question(anotherHostname),
		},
		name:    "rfc2308_nodata_response_type_3",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: question(anotherHostname),
			Ns:       nsRRs(),
			Extra:    glueRRs(),
		},
		name:    "rfc2308_nodata_referral_response",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeServerFailure),
			Question: question(anotherHostname),
		},
		name:    "servfail_response",
		wantTTL: ServFailMaxCacheTTL,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			got := cacheTTL(tc.resp, slogutil.NewDiscardLogger())
			assert.Equal(t, tc.wantTTL, got)
		})
	}
}
07070100000057000081A4000000000000000000000001679A649F000001F3000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/clock.gopackage proxy

import "time"

// clock provides the current time.  It's an interface so that tests can use
// their own implementation.
//
// TODO(e.burkov):  Move to golibs.
type clock interface {
	// Now returns the current local time.
	Now() (now time.Time)
}

// realClock is the [clock] that reports the system time using the [time]
// package.
type realClock struct{}

// type check
var _ clock = realClock{}

// Now implements the [clock] interface for realClock.  It returns the result
// of [time.Now].
func (realClock) Now() (now time.Time) {
	return time.Now()
}
07070100000058000081A4000000000000000000000001679A649F000034B5000000000000000000000000000000000000002000000000dnsproxy-0.75.0/proxy/config.gopackage proxy

import (
	"crypto/tls"
	"fmt"
	"log/slog"
	"net"
	"net/netip"
	"net/url"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/ameshkov/dnscrypt/v2"
)

// LogPrefix is the prefix for the log messages of the proxy.  It's used
// together with [slog.Default] when [Config.Logger] is nil.
const LogPrefix = "dnsproxy"

// RequestHandler is an optional custom handler for DNS requests.  If set, it's
// used instead of [Proxy.Resolve].  The returned error doesn't affect the
// processing of the request.  If the handler doesn't call [Proxy.Resolve]
// itself, it's responsible for calling the [ResponseHandler].
//
// TODO(e.burkov):  Use the same interface-based approach as
// [BeforeRequestHandler].
type RequestHandler func(p *Proxy, dctx *DNSContext) (err error)

// ResponseHandler is an optional custom handler called when a DNS query has
// been processed.  When called from [Proxy.Resolve], dctx contains the response
// message if the upstream or the cache succeeded.  err is non-nil only if the
// upstream failed to respond.
//
// TODO(e.burkov):  Use the same interface-based approach as
// [BeforeRequestHandler].
type ResponseHandler func(dctx *DNSContext, err error)

// Config contains all the fields necessary for proxy configuration.
//
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
type Config struct {
	// Logger is used as the base logger for the proxy service.  If nil,
	// [slog.Default] with [LogPrefix] is used.
	Logger *slog.Logger

	// TrustedProxies is the set of trusted CIDR networks used to detect the
	// addresses of proxy servers from which DoH requests should be handled.
	// A nil value makes Proxy not trust any address.
	TrustedProxies netutil.SubnetSet

	// PrivateSubnets is the set of private networks.  A client with an address
	// within this set is able to resolve PTR requests for addresses within
	// this set.
	PrivateSubnets netutil.SubnetSet

	// MessageConstructor is used to build DNS messages.  If nil, the default
	// constructor is used.
	MessageConstructor MessageConstructor

	// BeforeRequestHandler is an optional custom handler called before the
	// processing of each DNS request starts, see [BeforeRequestHandler].  If
	// it's nil, the default no-op implementation is used.
	BeforeRequestHandler BeforeRequestHandler

	// RequestHandler is an optional custom handler for DNS requests.  If set,
	// it's used instead of [Proxy.Resolve].  See [RequestHandler].
	RequestHandler RequestHandler

	// ResponseHandler is an optional custom handler called when a DNS query has
	// been processed.  See [ResponseHandler].
	ResponseHandler ResponseHandler

	// UpstreamConfig is the general set of DNS servers to forward requests to.
	UpstreamConfig *UpstreamConfig

	// PrivateRDNSUpstreamConfig is the set of upstream DNS servers for
	// resolving private IP addresses.  All the requests considered private are
	// resolved via these upstream servers.  Such queries finish with
	// [upstream.ErrNoUpstreams] if it's empty.
	PrivateRDNSUpstreamConfig *UpstreamConfig

	// Fallbacks is a list of fallback resolvers.  They are used if the general
	// set fails to respond.  A nil value disables fallbacks, but an empty one
	// is invalid.
	Fallbacks *UpstreamConfig

	// Userinfo is the sole permitted userinfo for the DoH basic authentication.
	// If Userinfo is set, all DoH queries are required to have this basic
	// authentication information.
	Userinfo *url.Userinfo

	// TLSConfig is the TLS configuration.  Required for DNS-over-TLS,
	// DNS-over-HTTPS, and DNS-over-QUIC servers.
	TLSConfig *tls.Config

	// DNSCryptResolverCert is the DNSCrypt resolver certificate.  Required for
	// DNSCrypt servers.
	DNSCryptResolverCert *dnscrypt.Cert

	// DNSCryptProviderName is the DNSCrypt provider name.  Required for
	// DNSCrypt servers.
	DNSCryptProviderName string

	// HTTPSServerName sets the Server header of the HTTPS server responses, if
	// not empty.
	HTTPSServerName string

	// UpstreamMode determines the logic through which upstreams are used.  If
	// not specified, [proxy.UpstreamModeLoadBalance] is used.
	UpstreamMode UpstreamMode

	// UDPListenAddr is the set of UDP addresses to listen on for plain
	// DNS-over-UDP requests.
	UDPListenAddr []*net.UDPAddr

	// TCPListenAddr is the set of TCP addresses to listen on for plain
	// DNS-over-TCP requests.
	TCPListenAddr []*net.TCPAddr

	// HTTPSListenAddr is the set of TCP addresses to listen on for
	// DNS-over-HTTPS requests.  Requires TLSConfig.
	HTTPSListenAddr []*net.TCPAddr

	// TLSListenAddr is the set of TCP addresses to listen on for DNS-over-TLS
	// requests.  Requires TLSConfig.
	TLSListenAddr []*net.TCPAddr

	// QUICListenAddr is the set of UDP addresses to listen on for
	// DNS-over-QUIC requests.  Requires TLSConfig.
	QUICListenAddr []*net.UDPAddr

	// DNSCryptUDPListenAddr is the set of UDP addresses to listen on for
	// DNSCrypt requests.  Requires DNSCryptResolverCert and
	// DNSCryptProviderName.
	DNSCryptUDPListenAddr []*net.UDPAddr

	// DNSCryptTCPListenAddr is the set of TCP addresses to listen on for
	// DNSCrypt requests.  Requires DNSCryptResolverCert and
	// DNSCryptProviderName.
	DNSCryptTCPListenAddr []*net.TCPAddr

	// BogusNXDomain is the set of networks used to transform responses into
	// NXDOMAIN ones if they contain at least a single IP address within these
	// networks.  It's similar to dnsmasq's "bogus-nxdomain".
	BogusNXDomain []netip.Prefix

	// DNS64Prefs is the set of NAT64 prefixes used for DNS64 handling.  Each
	// must be an IPv6 prefix of at most 96 bits.  If UseDNS64 is true and it's
	// empty, the default Well-Known Prefix is used.
	DNS64Prefs []netip.Prefix

	// RatelimitWhitelist is a list of IP addresses excluded from rate limiting.
	RatelimitWhitelist []netip.Addr

	// EDNSAddr is the ECS IP used in requests.
	EDNSAddr net.IP

	// TODO(s.chzhen):  Extract ratelimit settings to a separate structure.

	// RatelimitSubnetLenIPv4 is the subnet length for IPv4 addresses used for
	// rate limiting requests.  It must be within [0, 32] when Ratelimit is
	// non-zero.
	RatelimitSubnetLenIPv4 int

	// RatelimitSubnetLenIPv6 is the subnet length for IPv6 addresses used for
	// rate limiting requests.  It must be within [0, 128] when Ratelimit is
	// non-zero.
	RatelimitSubnetLenIPv6 int

	// Ratelimit is the maximum number of requests per second from a given IP
	// (0 to disable).
	Ratelimit int

	// CacheSizeBytes is the maximum cache size in bytes.
	CacheSizeBytes int

	// CacheMinTTL is the minimum TTL for cached DNS responses in seconds.
	CacheMinTTL uint32

	// CacheMaxTTL is the maximum TTL for cached DNS responses in seconds.
	CacheMaxTTL uint32

	// MaxGoroutines is the maximum number of goroutines processing DNS
	// requests.  Important for mobile users.
	//
	// TODO(a.garipov): Rename this to something like “MaxDNSRequestGoroutines”
	// in a later major version, as it doesn't actually limit all goroutines.
	MaxGoroutines uint

	// UDPBufferSize is the size of the read buffer on the underlying socket.
	// Larger read buffers can handle larger bursts of requests before packets
	// get dropped.
	UDPBufferSize int

	// FastestPingTimeout is the timeout for waiting for the first successful
	// dialing when the UpstreamMode is set to [UpstreamModeFastestAddr].
	// Non-positive value will be replaced with the default one.
	FastestPingTimeout time.Duration

	// RefuseAny makes proxy refuse the requests of type ANY.
	RefuseAny bool

	// HTTP3 enables HTTP/3 support for HTTPS server.
	HTTP3 bool

	// EnableEDNSClientSubnet enables the EDNS Client Subnet option.  DNS
	// requests to the upstream server will contain an OPT record with the
	// Client Subnet option.  If the original request already has this option
	// set, it's passed through as is.  Otherwise, it's set using the client IP
	// with subnet /24 (for IPv4) and /56 (for IPv6).
	//
	// If the upstream server supports ECS, it sets the subnet number in the
	// response.  This subnet number along with the client IP and other data is
	// used as a cache key.  Next time, if a client from the same subnet
	// requests this host name, the response is taken from the cache.  If a
	// client from a different subnet requests this host name, its request is
	// passed to the upstream server.
	//
	// If the upstream server doesn't support ECS (there's no subnet number in
	// the response), this response will be cached for all clients.
	//
	// If the client IP is private (i.e. not public), no EDNS record is added
	// to the request, so there will be no EDNS record in the response either.
	// These responses are stored in the general cache (without subnet), so
	// they will never be used for clients with public IP addresses.
	EnableEDNSClientSubnet bool

	// CacheEnabled defines if the response cache should be used.
	CacheEnabled bool

	// CacheOptimistic defines if the optimistic cache mechanism should be used.
	CacheOptimistic bool

	// UseDNS64 enables DNS64 handling.  If true, proxy will translate IPv4
	// answers into IPv6 answers using the first of DNS64Prefs.  Note also that
	// PTR requests for addresses within the specified networks are considered
	// private and will be forwarded as PrivateRDNSUpstreamConfig specifies.
	// They will be answered with NXDOMAIN if UsePrivateRDNS is false.
	UseDNS64 bool

	// UsePrivateRDNS defines if the PTR requests for private IP addresses
	// should be resolved via PrivateRDNSUpstreamConfig.  Note that it requires
	// a valid PrivateRDNSUpstreamConfig with at least a single general upstream
	// server.
	UsePrivateRDNS bool

	// PreferIPv6 tells the proxy to prefer IPv6 addresses when bootstrapping
	// upstreams that use hostnames.
	PreferIPv6 bool
}

// validateConfig checks the upstream, private RDNS, fallback, ratelimit, and
// upstream mode settings of the proxy, in that order, and returns the first
// problem found.  On success, it logs the notable settings.
//
// TODO(s.chzhen):  Use [validate.Interface] from golibs.
func (p *Proxy) validateConfig() (err error) {
	if err = p.UpstreamConfig.validate(); err != nil {
		return fmt.Errorf("validating general upstreams: %w", err)
	}

	// Invalid private RDNS upstreams are only a problem when they are going to
	// be used, unless there are no upstreams at all.
	err = ValidatePrivateConfig(p.PrivateRDNSUpstreamConfig, p.privateNets)
	if err != nil && (p.UsePrivateRDNS || errors.Is(err, upstream.ErrNoUpstreams)) {
		return fmt.Errorf("validating private RDNS upstreams: %w", err)
	}

	// Allow [Proxy.Fallbacks] to be nil, but not empty.  nil means not to use
	// fallbacks at all.
	//
	// NOTE(review):  Other validation errors of the fallbacks are ignored
	// here; confirm that this is intended.
	if err = p.Fallbacks.validate(); errors.Is(err, upstream.ErrNoUpstreams) {
		return fmt.Errorf("validating fallbacks: %w", err)
	}

	if err = p.validateRatelimit(); err != nil {
		return fmt.Errorf("validating ratelimit: %w", err)
	}

	switch p.UpstreamMode {
	case "", UpstreamModeFastestAddr, UpstreamModeLoadBalance, UpstreamModeParallel:
		// An empty mode is allowed as well as the known ones.
	default:
		return fmt.Errorf("bad upstream mode: %q", p.UpstreamMode)
	}

	p.logConfigInfo()

	return nil
}

// validateRatelimit returns an error if rate limiting is enabled and any of the
// subnet lengths is outside of the range allowed for its address family.
func (p *Proxy) validateRatelimit() (err error) {
	if p.Ratelimit == 0 {
		// Rate limiting is disabled, so the subnet lengths don't matter.
		return nil
	}

	if err = checkInclusion(p.RatelimitSubnetLenIPv4, 0, netutil.IPv4BitLen); err != nil {
		return fmt.Errorf("ratelimit subnet len ipv4 is invalid: %w", err)
	}

	if err = checkInclusion(p.RatelimitSubnetLenIPv6, 0, netutil.IPv6BitLen); err != nil {
		return fmt.Errorf("ratelimit subnet len ipv6 is invalid: %w", err)
	}

	return nil
}

// checkInclusion returns an error if n is outside of the inclusive range from
// minN to maxN.
func checkInclusion(n, minN, maxN int) (err error) {
	if n < minN {
		return fmt.Errorf("value %d less than min %d", n, minN)
	}

	if n > maxN {
		return fmt.Errorf("value %d greater than max %d", n, maxN)
	}

	return nil
}

// logConfigInfo logs the enabled optional features of the proxy configuration
// at the info level.
func (p *Proxy) logConfigInfo() {
	l := p.logger

	if p.CacheMinTTL > 0 || p.CacheMaxTTL > 0 {
		l.Info("cache ttl override is enabled", "min", p.CacheMinTTL, "max", p.CacheMaxTTL)
	}

	if p.Ratelimit > 0 {
		l.Info(
			"ratelimit is enabled",
			"rps", p.Ratelimit,
			"ipv4_subnet_mask_len", p.RatelimitSubnetLenIPv4,
			"ipv6_subnet_mask_len", p.RatelimitSubnetLenIPv6,
		)
	}

	if p.RefuseAny {
		l.Info("server will refuse requests of type any")
	}

	// NOTE(review):  The value logged under "prefix_len" is the number of
	// configured prefixes, not a prefix length.
	if n := len(p.BogusNXDomain); n > 0 {
		l.Info("bogus-nxdomain ip specified", "prefix_len", n)
	}

	if mode := p.UpstreamMode; mode != "" {
		l.Info("upstream mode is set", "mode", mode)
	}
}

// validateListenAddrs returns an error if no listen addresses are set, if the
// TLS-based listeners lack the TLS configuration, or if the DNSCrypt listeners
// lack the DNSCrypt certificate or provider name.
func (p *Proxy) validateListenAddrs() (err error) {
	if !p.hasListenAddrs() {
		return errors.Error("no listen address specified")
	}

	if err = p.validateTLSConfig(); err != nil {
		return fmt.Errorf("invalid tls configuration: %w", err)
	}

	hasDNSCryptConf := p.DNSCryptResolverCert != nil && p.DNSCryptProviderName != ""

	switch {
	case hasDNSCryptConf:
		return nil
	case p.DNSCryptTCPListenAddr != nil:
		return errors.Error("cannot create dnscrypt tcp listener without dnscrypt config")
	case p.DNSCryptUDPListenAddr != nil:
		return errors.Error("cannot create dnscrypt udp listener without dnscrypt config")
	default:
		return nil
	}
}

// validateTLSConfig returns an error if there is no TLS configuration while
// DNS-over-TLS, DNS-over-HTTPS, or DNS-over-QUIC listen addresses are set.
func (p *Proxy) validateTLSConfig() (err error) {
	if p.TLSConfig == nil {
		switch {
		case p.TLSListenAddr != nil:
			return errors.Error("tls listener configuration not found")
		case p.HTTPSListenAddr != nil:
			return errors.Error("https listener configuration not found")
		case p.QUICListenAddr != nil:
			return errors.Error("quic listener configuration not found")
		}
	}

	return nil
}

// hasListenAddrs reports whether at least one of the listen address lists of
// the proxy is non-nil.
func (p *Proxy) hasListenAddrs() bool {
	switch {
	case p.UDPListenAddr != nil,
		p.TCPListenAddr != nil,
		p.TLSListenAddr != nil,
		p.HTTPSListenAddr != nil,
		p.QUICListenAddr != nil,
		p.DNSCryptUDPListenAddr != nil,
		p.DNSCryptTCPListenAddr != nil:
		return true
	default:
		return false
	}
}
07070100000059000081A4000000000000000000000001679A649F000000AE000000000000000000000000000000000000002500000000dnsproxy-0.75.0/proxy/constructor.gopackage proxy

import (
	"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
)

// MessageConstructor creates DNS messages.  It's an alias for
// [dnsmsg.MessageConstructor] and is the type of [Config.MessageConstructor].
type MessageConstructor = dnsmsg.MessageConstructor
0707010000005A000081A4000000000000000000000001679A649F000025E5000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/dns64.gopackage proxy

import (
	"fmt"
	"net"
	"net/netip"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

const (
	// maxNAT64PrefixBitLen is the maximum length of a NAT64 prefix in bits.
	// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.2.
	maxNAT64PrefixBitLen = 96

	// NAT64PrefixLength is the length of a NAT64 prefix in bytes, i.e. the
	// number of leading bytes of a synthesized IPv6 address that precede the
	// embedded IPv4 address (16 - 4 = 12 bytes, or 96 bits).
	NAT64PrefixLength = net.IPv6len - net.IPv4len

	// maxDNS64SynTTL is the maximum TTL, in seconds, of synthesized DNS64
	// responses when the original response carries no SOA record.
	//
	// If the SOA RR was not delivered with the negative response to the AAAA
	// query, then the DNS64 SHOULD use the TTL of the original A RR or 600
	// seconds, whichever is shorter.
	//
	// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.1.7.
	maxDNS64SynTTL uint32 = 600
)

// setupDNS64 fills p.dns64Prefs from the configuration when DNS64 is enabled;
// it does nothing otherwise.
//
// Without any configured prefixes, only the Well-Known Prefix is used, as
// Section 5.2 of RFC 6147 prescribes.  Once at least one prefix is configured,
// the Well-Known Prefix is used only if it is listed explicitly.  Every
// configured prefix must be an IPv6 prefix no longer than 96 bits; it is
// stored masked.  The first prefix in the resulting set is the one used to
// synthesize AAAA records.
func (p *Proxy) setupDNS64() (err error) {
	if !p.Config.UseDNS64 {
		return nil
	}

	prefs := p.Config.DNS64Prefs
	if len(prefs) == 0 {
		p.dns64Prefs = netutil.SliceSubnetSet{dns64WellKnownPref}

		return nil
	}

	for i := range prefs {
		pref := prefs[i]

		switch {
		case !pref.Addr().Is6():
			return fmt.Errorf("prefix at index %d: %q is not an IPv6 prefix", i, pref)
		case pref.Bits() > maxNAT64PrefixBitLen:
			return fmt.Errorf("prefix at index %d: %q is too long for DNS64", i, pref)
		default:
			p.dns64Prefs = append(p.dns64Prefs, pref.Masked())
		}
	}

	return nil
}

// checkDNS64 decides whether the AAAA request req, answered with resp, needs
// DNS64 synthesis.  If so, it returns a copy of req with a fresh ID and the
// question type changed to A; otherwise it returns nil.  For NOERROR responses
// it also removes AAAA records within the NAT64 prefixes from resp.Answer.
// Both req and resp must not be nil.
//
// See https://datatracker.ietf.org/doc/html/rfc6147.
func (p *Proxy) checkDNS64(req, resp *dns.Msg) (dns64Req *dns.Msg) {
	if len(p.dns64Prefs) == 0 {
		return nil
	}

	q := req.Question[0]
	if q.Qclass != dns.ClassINET || q.Qtype != dns.TypeAAAA {
		// DNS64 operation for classes other than IN is undefined, and a DNS64
		// MUST behave as though no DNS64 function is configured.
		return nil
	}

	if resp.Rcode == dns.RcodeNameError {
		// A result with RCODE=3 (Name Error) is handled according to normal DNS
		// operation (which is normally to return the error to the client).
		return nil
	}

	if resp.Rcode == dns.RcodeSuccess {
		// Keep only the AAAA records outside the excluded ranges; if any of
		// them (or an alias) is left, the response is used as is.
		filtered, hasAnswers := p.filterNAT64Answers(resp.Answer)
		resp.Answer = filtered
		if hasAnswers {
			return nil
		}
	}

	// Any other RCODE is treated as though the RCODE were 0 and the answer
	// section were empty.
	dns64Req = req.Copy()
	dns64Req.Id = dns.Id()
	dns64Req.Question[0].Qtype = dns.TypeA

	return dns64Req
}

// filterNAT64Answers returns rrs without the AAAA records whose addresses are
// within one of the configured DNS64 prefixes; the remaining records keep
// their order.  AAAA records carrying an invalid address are logged and dropped
// as well.  hasAnswers is true if the result still contains at least one AAAA,
// CNAME, or DNAME record.
func (p *Proxy) filterNAT64Answers(rrs []dns.RR) (filtered []dns.RR, hasAnswers bool) {
	filtered = make([]dns.RR, 0, len(rrs))
	for _, rr := range rrs {
		keep, isAnswer := true, false

		switch v := rr.(type) {
		case *dns.AAAA:
			addr, err := netutil.IPToAddrNoMapped(v.AAAA)
			switch {
			case err != nil:
				p.logger.Error("bad aaaa record", slogutil.KeyError, err)
				keep = false
			case p.dns64Prefs.Contains(addr):
				// The address is within an excluded prefix.
				keep = false
			default:
				isAnswer = true
			}
		case *dns.CNAME, *dns.DNAME:
			// If the response contains a CNAME or a DNAME, then the CNAME or
			// DNAME chain is followed until the first terminating A or AAAA
			// record is reached.
			//
			// Just treat CNAME and DNAME responses as passable answers since
			// AdGuard Home doesn't follow any of these chains except the
			// dnsrewrite-defined ones.
			isAnswer = true
		}

		if keep {
			filtered = append(filtered, rr)
		}

		hasAnswers = hasAnswers || isAnswer
	}

	return filtered, hasAnswers
}

// synthDNS64 synthesizes a DNS64 response by rewriting origResp, the response
// to the original AAAA request origReq, with data from resp, the response to
// the corresponding A request.  The answer section of origResp becomes the
// synthesized AAAA records.  Its authority and additional sections and its
// RCODE are taken from resp.  It returns true if origResp was actually
// modified.  It returns false, leaving origResp intact, if resp has no answers
// or if any of its A records is invalid.
func (p *Proxy) synthDNS64(origReq, origResp, resp *dns.Msg) (ok bool) {
	if len(resp.Answer) == 0 {
		// If there is an empty answer, then the DNS64 responds to the original
		// querying client with the answer the DNS64 received to the original
		// (initiator's) query.
		return false
	}

	// The Time to Live (TTL) field is set to the minimum of the TTL of the
	// original A RR and the SOA RR for the queried domain.  If the original
	// response contains no SOA records, the minimum of the TTL of the original
	// A RR and [maxDNS64SynTTL] should be used.  See [maxDNS64SynTTL].
	soaTTL := maxDNS64SynTTL

	// Domain names are compared case-insensitively, see RFC 4343, so compare
	// their canonical (lowercased, fully-qualified) forms.
	qname := dns.CanonicalName(origReq.Question[0].Name)
	for _, rr := range origResp.Ns {
		hdr := rr.Header()
		if hdr.Rrtype == dns.TypeSOA && dns.CanonicalName(hdr.Name) == qname {
			soaTTL = hdr.Ttl

			break
		}
	}

	newAns := make([]dns.RR, 0, len(resp.Answer))
	for _, ans := range resp.Answer {
		rr := p.synthRR(ans, soaTTL)
		if rr == nil {
			// The error should have already been logged.
			return false
		}

		newAns = append(newAns, rr)
	}

	// The original response may have had an RCODE other than NOERROR, for
	// example SERVFAIL, which RFC 6147 Section 5.1.2 says to treat as an empty
	// answer.  Keeping that RCODE would send the synthesized answers with an
	// error status, so use the RCODE of the response the answers came from.
	origResp.Rcode = resp.Rcode
	origResp.Answer = newAns
	origResp.Ns = resp.Ns
	origResp.Extra = resp.Extra

	return true
}

// dns64WellKnownPref is the default prefix to use in an algorithmic mapping for
// DNS64, the Well-Known Prefix 64:ff9b::/96.  It is used when DNS64 is enabled
// without explicitly configured prefixes.  PTR requests for addresses within it
// are also recognized by shouldStripDNS64 regardless of the configured prefixes.
// See https://datatracker.ietf.org/doc/html/rfc6052#section-2.1.
var dns64WellKnownPref = netip.MustParsePrefix("64:ff9b::/96")

// shouldStripDNS64 reports whether DNS64 is enabled and req is a PTR request
// for a reversed address within one of the configured DNS64 prefixes or the
// Well-Known Prefix.
//
// The requirement is to match any Pref64::/n used at the site, and not merely
// the locally configured Pref64::/n.  This is because end clients could ask for
// a PTR record matching an address received through a different (site-provided)
// DNS64.
//
// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.3.1.
func (p *Proxy) shouldStripDNS64(req *dns.Msg) (ok bool) {
	if len(p.dns64Prefs) == 0 {
		return false
	}

	q := req.Question[0]
	if q.Qtype != dns.TypePTR {
		return false
	}

	ip, err := netutil.IPFromReversedAddr(q.Name)
	if err != nil {
		p.logger.Debug("failed to parse ip from ptr request", slogutil.KeyError, err)

		return false
	}

	if p.dns64Prefs.Contains(ip) {
		p.logger.Debug("the ip is within dns64 custom prefix set", "ip", ip)

		return true
	}

	if dns64WellKnownPref.Contains(ip) {
		p.logger.Debug("the ip is within dns64 well-known prefix", "ip", ip)

		return true
	}

	return false
}

// mapDNS64 embeds the IPv4 address addr into the first configured DNS64
// prefix and returns the resulting 16-byte IPv6 address.  addr must be a valid
// IPv4 address.  It panics if no DNS64 prefixes are configured, since
// synthesis must not be performed unless the DNS64 function is enabled.
func (p *Proxy) mapDNS64(addr netip.Addr) (mapped net.IP) {
	// The prefixes are expected to be masked already during setup, so the
	// prefix address is used as is.
	pref := p.dns64Prefs[0].Addr().As16()
	ip4 := addr.As4()

	mapped = make(net.IP, 0, net.IPv6len)
	mapped = append(mapped, pref[:NAT64PrefixLength]...)
	mapped = append(mapped, ip4[:]...)

	return mapped
}

// synthRR converts an A record into a DNS64-synthesized AAAA record, as RFC
// 6147 describes.  Records of any other type are returned unchanged.  The TTL
// of the new record is the smaller of the A record's TTL and soaTTL.  It
// returns nil, after logging the error, if the A record holds an invalid IPv4
// address.
func (p *Proxy) synthRR(rr dns.RR, soaTTL uint32) (result dns.RR) {
	a, isA := rr.(*dns.A)
	if !isA {
		return rr
	}

	ip4, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
	if err != nil {
		p.logger.Error("bad a record", slogutil.KeyError, err)

		return nil
	}

	hdr := dns.RR_Header{
		Name:   a.Hdr.Name,
		Rrtype: dns.TypeAAAA,
		Class:  a.Hdr.Class,
		Ttl:    min(a.Hdr.Ttl, soaTTL),
	}

	return &dns.AAAA{
		Hdr:  hdr,
		AAAA: p.mapDNS64(ip4),
	}
}

// performDNS64 performs DNS64 synthesis for origReq if origResp calls for it.
// origResp is modified in place on success.  It returns the upstream that
// resolved the auxiliary A request if a response was synthesized, and nil
// otherwise.
func (p *Proxy) performDNS64(
	origReq *dns.Msg,
	origResp *dns.Msg,
	upstreams []upstream.Upstream,
) (u upstream.Upstream) {
	if origResp == nil {
		return nil
	}

	dns64Req := p.checkDNS64(origReq, origResp)
	if dns64Req == nil {
		return nil
	}

	host := origReq.Question[0].Name
	p.logger.Debug("received an empty aaaa response, checking dns64", "host", host)

	dns64Resp, dns64Ups, err := p.exchangeUpstreams(dns64Req, upstreams)
	switch {
	case err != nil:
		p.logger.Error("dns64 request failed", slogutil.KeyError, err)

		return nil
	case dns64Resp == nil, !p.synthDNS64(origReq, origResp, dns64Resp):
		// Case expressions are evaluated left to right, so synthDNS64 is only
		// called with a non-nil response.
		return nil
	}

	p.logger.Debug("synthesized aaaa response", "host", host)

	return dns64Ups
}
0707010000005B000081A4000000000000000000000001679A649F00002784000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/dns64_test.gopackage proxy

import (
	"context"
	"net"
	"net/netip"
	"sync"
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// ipv4OnlyFqdn is a test domain name that the fake upstreams in these tests
// resolve only to A records, so AAAA requests for it go through DNS64.
const ipv4OnlyFqdn = "ipv4.only."

// TestDNS64Race sends many concurrent AAAA requests for an IPv4-only domain
// over separate TCP connections and checks that each gets a synthesized AAAA
// answer.  It is meant to be run with the race detector.
func TestDNS64Race(t *testing.T) {
	ans := newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, net.ParseIP("1.2.3.4"))
	ups := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			resp = (&dns.Msg{}).SetReply(req)
			if req.Question[0].Qtype == dns.TypeA {
				resp.Answer = []dns.RR{dns.Copy(ans)}
			}

			return resp, nil
		},
		onAddress: func() (addr string) { return "fake.address" },
		onClose:   func() (err error) { return nil },
	}
	localUps := &fakeUpstream{
		onExchange: func(_ *dns.Msg) (_ *dns.Msg, _ error) { panic("not implemented") },
		onAddress:  func() (addr string) { return "fake.address" },
		onClose:    func() (err error) { return nil },
	}

	dnsProxy := mustNew(t, &Config{
		Logger:         slogutil.NewDiscardLogger(),
		UDPListenAddr:  []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:  []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		PrivateRDNSUpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{localUps},
		},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,

		UseDNS64:       true,
		UsePrivateRDNS: true,
		// Valid NAT-64 prefix for 2001:67c:27e4:15::64 server.
		DNS64Prefs: []netip.Prefix{netip.MustParsePrefix("2001:67c:27e4:1064::/96")},
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	syncCh := make(chan struct{})

	// Send requests.
	g := &sync.WaitGroup{}
	g.Add(testMessagesCount)

	addr := dnsProxy.Addr(ProtoTCP).String()
	for range testMessagesCount {
		// The [dns.Conn] isn't safe for concurrent use despite the requirements
		// from the [net.Conn] documentation.
		var conn *dns.Conn
		conn, err = dns.Dial("tcp", addr)
		require.NoError(t, err)

		// Don't leak the client connections.  Cleanups run in reverse order,
		// so these are closed before the proxy is shut down, and only after
		// g.Wait below has returned.
		testutil.CleanupAndRequireSuccess(t, conn.Close)

		go sendTestAAAAMessageAsync(conn, g, ipv4OnlyFqdn, syncCh)
	}

	close(syncCh)
	g.Wait()
}

// sendTestAAAAMessageAsync waits for syncCh to be closed, sends an AAAA request
// for fqdn over conn, and requires a successful response whose first answer is
// an AAAA record.  It panics on failure, since it runs in its own goroutine,
// and marks g as done when it returns.
func sendTestAAAAMessageAsync(conn *dns.Conn, g *sync.WaitGroup, fqdn string, syncCh chan struct{}) {
	pt := testutil.PanicT{}

	defer g.Done()

	req := (&dns.Msg{}).SetQuestion(fqdn, dns.TypeAAAA)
	<-syncCh

	err := conn.WriteMsg(req)
	require.NoError(pt, err)

	res, err := conn.ReadMsg()
	require.NoError(pt, err)

	// testify expects the expected value first and the actual one second.
	require.Equal(pt, dns.RcodeSuccess, res.Rcode)
	require.NotEmpty(pt, res.Answer)

	require.IsType(pt, &dns.AAAA{}, res.Answer[0])
}

// newRR builds a resource record of type qtype with the given owner name and
// TTL in class IN.  val supplies the record data: a net.IP for A and AAAA, and
// a string for CNAME and PTR.  val is ignored for SOA, whose fields are derived
// from name.  The test fails if qtype is unsupported or val has the wrong type.
func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns.RR) {
	t.Helper()

	hdr := dns.RR_Header{
		Name:   name,
		Rrtype: qtype,
		Class:  dns.ClassINET,
		Ttl:    ttl,
	}

	switch qtype {
	case dns.TypeA:
		return &dns.A{Hdr: hdr, A: testutil.RequireTypeAssert[net.IP](t, val)}
	case dns.TypeAAAA:
		return &dns.AAAA{Hdr: hdr, AAAA: testutil.RequireTypeAssert[net.IP](t, val)}
	case dns.TypeCNAME:
		return &dns.CNAME{Hdr: hdr, Target: testutil.RequireTypeAssert[string](t, val)}
	case dns.TypeSOA:
		return &dns.SOA{
			Hdr:     hdr,
			Ns:      "ns." + name,
			Mbox:    "hostmaster." + name,
			Serial:  1,
			Refresh: 1,
			Retry:   1,
			Expire:  1,
			Minttl:  1,
		}
	case dns.TypePTR:
		return &dns.PTR{Hdr: hdr, Ptr: testutil.RequireTypeAssert[string](t, val)}
	default:
		t.Fatalf("unsupported qtype: %d", qtype)

		// Not reached, since Fatalf stops the test goroutine.
		return nil
	}
}

// TestProxy_Resolve_dns64 checks request handling with DNS64 enabled.  It
// covers plain A and AAAA responses, synthesis with and without an SOA record,
// filtering of AAAA records within the NAT64 prefix, and PTR routing for
// DNS64-mapped and global addresses.
func TestProxy_Resolve_dns64(t *testing.T) {
	const (
		ipv6Domain    = "ipv6.only."
		soaDomain     = "ipv4.soa."
		mappedDomain  = "filterable.ipv6."
		anotherDomain = "another.domain."

		pointedDomain = "local1234.ipv4."
		globDomain    = "real1234.ipv4."
	)

	someIPv4 := net.IP{1, 2, 3, 4}
	someIPv6 := net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
	mappedIPv6 := net.ParseIP("64:ff9b::102:304")

	ptr64Domain, err := netutil.IPToReversedAddr(mappedIPv6)
	require.NoError(t, err)
	ptr64Domain = dns.Fqdn(ptr64Domain)

	ptrGlobDomain, err := netutil.IPToReversedAddr(someIPv4)
	require.NoError(t, err)
	ptrGlobDomain = dns.Fqdn(ptrGlobDomain)

	localCliAddr := netip.MustParseAddrPort("192.168.1.1:1234")

	const (
		sectionAnswer = iota
		sectionAuthority
		sectionAdditional

		sectionsNum
	)

	// answerMap is a convenience alias for describing the upstream response for
	// a given question type.
	type answerMap = map[uint16][sectionsNum][]dns.RR

	pt := testutil.PanicT{}
	newUps := func(answers answerMap) (u upstream.Upstream) {
		return &fakeUpstream{
			onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
				q := req.Question[0]
				require.Contains(pt, answers, q.Qtype)

				answer := answers[q.Qtype]

				resp = (&dns.Msg{}).SetReply(req)
				resp.Answer = answer[sectionAnswer]
				resp.Ns = answer[sectionAuthority]
				resp.Extra = answer[sectionAdditional]

				return resp, nil
			},
			onAddress: func() (addr string) { return "fake.address" },
			onClose:   func() (err error) { return nil },
		}
	}

	localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain)
	localUps := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			// testify expects the expected value first and the actual one
			// second.
			require.Equal(pt, ptr64Domain, req.Question[0].Name)
			resp = (&dns.Msg{}).SetReply(req)
			resp.Answer = []dns.RR{localRR}

			return resp, nil
		},
		onAddress: func() (addr string) { return "fake.local.address" },
		onClose:   func() (err error) { return nil },
	}

	testCases := []struct {
		name    string
		qname   string
		upsAns  answerMap
		wantAns []dns.RR
		qtype   uint16
	}{{
		name:  "simple_a",
		qname: ipv4OnlyFqdn,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {},
		},
		wantAns: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   ipv4OnlyFqdn,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			A: someIPv4,
		}},
		qtype: dns.TypeA,
	}, {
		name:  "simple_aaaa",
		qname: ipv6Domain,
		upsAns: answerMap{
			dns.TypeA: {},
			dns.TypeAAAA: {
				sectionAnswer: {newRR(t, ipv6Domain, dns.TypeAAAA, 3600, someIPv6)},
			},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   ipv6Domain,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			AAAA: someIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "actual_dns64",
		qname: ipv4OnlyFqdn,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   ipv4OnlyFqdn,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    maxDNS64SynTTL,
			},
			AAAA: mappedIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "actual_dns64_soattl",
		qname: soaDomain,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, soaDomain, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {
				sectionAuthority: {newRR(t, soaDomain, dns.TypeSOA, maxDNS64SynTTL+50, nil)},
			},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   soaDomain,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    maxDNS64SynTTL + 50,
			},
			AAAA: mappedIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "filtered",
		qname: mappedDomain,
		upsAns: answerMap{
			dns.TypeA: {},
			dns.TypeAAAA: {
				sectionAnswer: {
					newRR(t, mappedDomain, dns.TypeAAAA, 3600, net.ParseIP("64:ff9b::506:708")),
					newRR(t, mappedDomain, dns.TypeCNAME, 3600, anotherDomain),
				},
			},
		},
		wantAns: []dns.RR{&dns.CNAME{
			Hdr: dns.RR_Header{
				Name:   mappedDomain,
				Rrtype: dns.TypeCNAME,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Target: anotherDomain,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:   "ptr",
		qname:  ptr64Domain,
		upsAns: nil,
		wantAns: []dns.RR{&dns.PTR{
			Hdr: dns.RR_Header{
				Name:   ptr64Domain,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Ptr: pointedDomain,
		}},
		qtype: dns.TypePTR,
	}, {
		name:  "ptr_glob",
		qname: ptrGlobDomain,
		upsAns: answerMap{
			dns.TypePTR: {
				sectionAnswer: {newRR(t, ptrGlobDomain, dns.TypePTR, 3600, globDomain)},
			},
		},
		wantAns: []dns.RR{&dns.PTR{
			Hdr: dns.RR_Header{
				Name:   ptrGlobDomain,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Ptr: globDomain,
		}},
		qtype: dns.TypePTR,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			p := mustNew(t, &Config{
				Logger:        slogutil.NewDiscardLogger(),
				UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
				TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
				UpstreamConfig: &UpstreamConfig{
					Upstreams: []upstream.Upstream{newUps(tc.upsAns)},
				},
				PrivateRDNSUpstreamConfig: &UpstreamConfig{
					Upstreams: []upstream.Upstream{localUps},
				},
				TrustedProxies:         defaultTrustedProxies,
				RatelimitSubnetLenIPv4: 24,
				RatelimitSubnetLenIPv6: 64,
				CacheEnabled:           true,

				UseDNS64:       true,
				UsePrivateRDNS: true,
				PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
			})

			ctx := context.Background()

			// Use a subtest-local error instead of reassigning the one
			// captured from the parent test.
			err := p.Start(ctx)
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

			dctx := &DNSContext{
				Req:  (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype),
				Addr: localCliAddr,
			}

			err = p.handleDNSRequest(dctx)
			require.NoError(t, err)

			res := dctx.Res
			require.NotNil(t, res)
			assert.Equal(t, tc.wantAns, res.Answer)
		})
	}
}
0707010000005C000081A4000000000000000000000001679A649F00001DDF000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/dnscontext.gopackage proxy

import (
	"net"
	"net/http"
	"net/netip"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
)

// DNSContext represents the context of a single DNS request: the request and
// response messages, the client and the protocol it used, and the
// protocol-specific objects needed to write the response back.
type DNSContext struct {
	// Conn is the underlying client connection.  It is nil if Proto is
	// ProtoDNSCrypt, ProtoHTTPS, or ProtoQUIC.
	Conn net.Conn

	// QUICConnection is the QUIC session from which we got the query.  For
	// [ProtoQUIC] only.
	QUICConnection quic.Connection

	// QUICStream is the QUIC stream from which we got the query.  For
	// [ProtoQUIC] only.
	QUICStream quic.Stream

	// Upstream is the upstream that resolved the request.  In case of cached
	// response it's nil.
	Upstream upstream.Upstream

	// DNSCryptResponseWriter is the writer used to respond to a DNSCrypt
	// query.
	DNSCryptResponseWriter dnscrypt.ResponseWriter

	// HTTPResponseWriter is the HTTP response writer.  For DoH only.
	HTTPResponseWriter http.ResponseWriter

	// HTTPRequest is the HTTP request.  For DoH only.
	HTTPRequest *http.Request

	// ReqECS is the EDNS Client Subnet used in the request.
	ReqECS *net.IPNet

	// CustomUpstreamConfig is the upstreams configuration used only for current
	// request.  The Resolve method of Proxy uses it instead of the default
	// servers if it's not nil.
	CustomUpstreamConfig *CustomUpstreamConfig

	// queryStatistics contains the DNS query statistics for both the upstream
	// and fallback DNS servers.
	queryStatistics *QueryStatistics

	// Req is the request message.
	Req *dns.Msg

	// Res is the response message.
	Res *dns.Msg

	// Proto is the DNS protocol of the query.
	Proto Proto

	// RequestedPrivateRDNS is the subnet extracted from the ARPA domain of
	// request's question if it's a PTR, SOA, or NS query for a private IP
	// address.  It can be a single-address subnet as well as a zero-length one.
	RequestedPrivateRDNS netip.Prefix

	// localIP is the local IP address, used for the UDP socket to call
	// udpMakeOOBWithSrc.
	localIP netip.Addr

	// Addr is the address of the client.
	Addr netip.AddrPort

	// DoQVersion is the DoQ protocol version.  It can (and should) be read
	// from ALPN, but in the current version we also use the way DNS messages
	// are encoded as a signal.
	DoQVersion DoQVersion

	// RequestID is an opaque numerical identifier of this request that is
	// guaranteed to be unique across requests processed by a single Proxy
	// instance.
	RequestID uint64

	// udpSize is the UDP buffer size from request's EDNS0 RR if present, or
	// the default one otherwise.  Zero means that the flags and the size
	// haven't been calculated yet, see calcFlagsAndSize.
	udpSize uint16

	// IsPrivateClient is true if the client's address is considered private
	// according to the configured private subnet set.
	IsPrivateClient bool

	// adBit is the authenticated data flag from the request.
	adBit bool

	// hasEDNS0 reflects if the request has EDNS0 RRs.
	hasEDNS0 bool

	// doBit is the DNSSEC OK flag from request's EDNS0 RR if present.
	doBit bool
}

// newDNSContext creates a DNSContext for req received from addr over proto and
// assigns it the next request identifier from the proxy's counter.
//
// TODO(e.burkov):  Consider creating DNSContext with this everywhere, to
// actually respect the contract of DNSContext.RequestID field.
func (p *Proxy) newDNSContext(proto Proto, req *dns.Msg, addr netip.AddrPort) (d *DNSContext) {
	d = &DNSContext{
		Addr:  addr,
		Proto: proto,
		Req:   req,
	}
	d.RequestID = p.counter.Add(1)

	return d
}

// QueryStatistics returns the DNS query statistics collected for the upstream
// and fallback DNS servers.  It returns nil until a DNS lookup has been
// performed.
//
// What the statistics contain depends on the outcome of the request and on the
// upstream mode:
//
//   - When the query was resolved successfully, they hold the DNS lookup
//     duration of the main resolver.
//
//   - When the answer came from the cache, they hold a single
//     [UpstreamStatistics] entry with the IsCached property set to true.
//
//   - When the upstream mode is [UpstreamModeFastestAddr] and the query was
//     resolved successfully, they hold the DNS lookup duration or error for
//     every main upstream.
//
//   - When the fallback resolver answered, they hold the DNS lookup errors of
//     every main upstream and the query duration of the fallback resolver.
//
//   - When the query couldn't be resolved at all, they hold the DNS lookup
//     errors of every main and fallback resolver.
func (dctx *DNSContext) QueryStatistics() (s *QueryStatistics) {
	s = dctx.queryStatistics

	return s
}

// calcFlagsAndSize lazily fills in the request-derived values the Resolve
// method needs: the AD flag, the EDNS0 presence, the DO flag, and the UDP
// buffer size.  It does nothing if there is no request or if the values have
// already been calculated, which a non-zero udpSize indicates.
func (dctx *DNSContext) calcFlagsAndSize() {
	req := dctx.Req
	if req == nil || dctx.udpSize != 0 {
		return
	}

	dctx.adBit = req.AuthenticatedData

	opt := req.IsEdns0()
	if opt == nil {
		dctx.udpSize = defaultUDPBufSize

		return
	}

	dctx.hasEDNS0 = true
	dctx.doBit = opt.Do()
	dctx.udpSize = opt.UDPSize()
}

// scrub prepares dctx.Res to be written to the client.  It adds an EDNS0 RR
// if the request had one and the response doesn't, truncates the response to
// the size the client can accept, and enables compression.  It does nothing
// if either the request or the response is missing.
func (dctx *DNSContext) scrub() {
	req, res := dctx.Req, dctx.Res
	if req == nil || res == nil {
		return
	}

	// The flags and the size below must be calculated beforehand.
	dctx.calcFlagsAndSize()

	// RFC-6891 (https://tools.ietf.org/html/rfc6891) states that response
	// mustn't contain an EDNS0 RR if the request doesn't include it.
	//
	// See https://github.com/AdguardTeam/dnsproxy/issues/132.
	if dctx.hasEDNS0 && res.IsEdns0() == nil {
		res.SetEdns0(dctx.udpSize, dctx.doBit)
	}

	maxSize := dnsSize(dctx.Proto == ProtoUDP, req)
	res.Truncate(int(maxSize))

	// Some devices require DNS message compression.
	res.Compress = true
}

// dnsSize returns the largest response size the client accepts.  For non-UDP
// transports that is [dns.MaxMsgSize].  For UDP it is the buffer size
// advertised in the request's OPT record, but never less than
// [dns.MinMsgSize], which is also used when there is no OPT record.
func dnsSize(isUDP bool, r *dns.Msg) (size uint16) {
	if !isUDP {
		return dns.MaxMsgSize
	}

	size = dns.MinMsgSize
	if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > size {
		size = opt.UDPSize()
	}

	return size
}

// DoQVersion is an enumeration of the supported DNS-over-QUIC (DoQ) protocol
// versions.
type DoQVersion int

const (
	// DoQv1Draft represents old DoQ draft versions that do not send a 2-octet
	// prefix with the DNS message length.
	//
	// TODO(ameshkov): remove in the end of 2024.
	DoQv1Draft DoQVersion = 0x00

	// DoQv1 represents DoQ v1.0, which, unlike the drafts, prefixes each DNS
	// message with its 2-octet length.  See
	// https://www.rfc-editor.org/rfc/rfc9250.html.
	DoQv1 DoQVersion = 0x01
)

// CustomUpstreamConfig contains an upstreams configuration with an optional
// cache of its own.  Use [NewCustomUpstreamConfig] to create it.
type CustomUpstreamConfig struct {
	// upstream is the upstream configuration.
	upstream *UpstreamConfig

	// cache is an optional cache for upstreams in the current configuration.
	// It is disabled if nil.
	//
	// TODO(d.kolyshev): Move this cache to [UpstreamConfig].
	cache *cache
}

// NewCustomUpstreamConfig returns a custom upstream configuration wrapping u.
// If cacheEnabled is true, the configuration gets its own cache of cacheSize,
// which takes EDNS Client Subnet into account if enableEDNSClientSubnet is
// true.  Otherwise the configuration has no cache.
func NewCustomUpstreamConfig(
	u *UpstreamConfig,
	cacheEnabled bool,
	cacheSize int,
	enableEDNSClientSubnet bool,
) (c *CustomUpstreamConfig) {
	c = &CustomUpstreamConfig{upstream: u}
	if !cacheEnabled {
		return c
	}

	// TODO(d.kolyshev): Support optimistic with newOptimisticResolver.
	c.cache = newCache(cacheSize, enableEDNSClientSubnet, false)

	return c
}

// Close closes the upstreams of the custom configuration.  It is a no-op if
// there is no upstream configuration.
func (c *CustomUpstreamConfig) Close() (err error) {
	if u := c.upstream; u != nil {
		return u.Close()
	}

	return nil
}

// ClearCache removes all items from the configuration's cache, both the plain
// ones and the ones stored with a subnet.  It does nothing if caching is
// disabled.
func (c *CustomUpstreamConfig) ClearCache() {
	if cc := c.cache; cc != nil {
		cc.clearItems()
		cc.clearItemsWithSubnet()
	}
}
0707010000005D000081A4000000000000000000000001679A649F00000219000000000000000000000000000000000000002000000000dnsproxy-0.75.0/proxy/errors.go//go:build !plan9
// +build !plan9

package proxy

import (
	"syscall"

	"github.com/AdguardTeam/golibs/errors"
)

// isEPIPE reports whether err is, or wraps, syscall.EPIPE.  It returns false
// for a nil error.  syscall.EPIPE exists on all OSes except for Plan 9.
// Validate with:
//
//	$ for os in $(go tool dist list | cut -d / -f 1 | sort -u)
//	do
//	        echo -n "$os"
//	        env GOOS="$os" go doc syscall.EPIPE | grep -F -e EPIPE
//	done
//
// For the Plan 9 version see ./errors_plan9.go.
func isEPIPE(err error) (ok bool) {
	ok = errors.Is(err, syscall.EPIPE)

	return ok
}
0707010000005E000081A4000000000000000000000001679A649F00000232000000000000000000000000000000000000002600000000dnsproxy-0.75.0/proxy/errors_plan9.go//go:build plan9
// +build plan9

package proxy

import "strings"

// isEPIPE checks if the underlying error is EPIPE.  It returns false for a nil
// error, just like the version for other OSes.  Plan 9 relies on error strings
// instead of error codes.  I couldn't find the exact constant with the text
// returned by a write on a closed socket, but it seems to be "sys: write on
// closed pipe".  See Plan 9's "man 2 notify".
//
// We don't currently support Plan 9, so it's not critical, but when we do, this
// needs to be rechecked.
func isEPIPE(err error) (ok bool) {
	if err == nil {
		// Calling Error on a nil error would panic.
		return false
	}

	return strings.Contains(err.Error(), "write on closed pipe")
}
0707010000005F000081A4000000000000000000000001679A649F000002D1000000000000000000000000000000000000002500000000dnsproxy-0.75.0/proxy/errors_test.go//go:build !plan9
// +build !plan9

package proxy

import (
	"fmt"
	"syscall"
	"testing"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/stretchr/testify/assert"
)

func TestIsEPIPE(t *testing.T) {
	type testCase struct {
		err  error
		name string
		want bool
	}

	testCases := []testCase{{
		name: "nil",
		err:  nil,
		want: false,
	}, {
		name: "epipe",
		err:  syscall.EPIPE,
		want: true,
	}, {
		name: "not_epipe",
		err:  errors.Error("test error"),
		want: false,
	}, {
		name: "wrapped_epipe",
		err:  fmt.Errorf("test error: %w", syscall.EPIPE),
		want: true,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			got := isEPIPE(tc.err)
			assert.Equal(t, tc.want, got)
		})
	}
}
07070100000060000081A4000000000000000000000001679A649F00001011000000000000000000000000000000000000002200000000dnsproxy-0.75.0/proxy/exchange.gopackage proxy

import (
	"fmt"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
	"gonum.org/v1/gonum/stat/sampleuv"
)

// exchangeUpstreams resolves req using the given upstreams.  It returns the DNS
// response, the upstream that successfully resolved the request, and the error
// if any.
//
// In the parallel mode, and in the fastest-address mode for A and AAAA
// requests, the work is delegated to the corresponding helpers.  Otherwise the
// upstreams are tried one by one in a random order weighted by their average
// round-trip time, until one of them succeeds.  An error is returned if ups is
// empty in this load-balancing mode.
func (p *Proxy) exchangeUpstreams(
	req *dns.Msg,
	ups []upstream.Upstream,
) (resp *dns.Msg, u upstream.Upstream, err error) {
	switch p.UpstreamMode {
	case UpstreamModeParallel:
		return upstream.ExchangeParallel(ups, req)
	case UpstreamModeFastestAddr:
		switch req.Question[0].Qtype {
		case dns.TypeA, dns.TypeAAAA:
			return p.fastestAddr.ExchangeFastest(req, ups)
		default:
			// Go on to the load-balancing mode.
		}
	default:
		// Go on to the load-balancing mode.
	}

	switch len(ups) {
	case 0:
		// Without this check the loop below wouldn't run at all, and the
		// joined error would be nil, producing a meaningless error message.
		return nil, nil, errors.Error("no upstreams specified")
	case 1:
		u = ups[0]
		resp, _, err = p.exchange(u, req, p.time)
		if err != nil {
			return nil, nil, err
		}

		// TODO(e.burkov):  Consider updating the RTT of a single upstream.

		return resp, u, nil
	}

	w := sampleuv.NewWeighted(p.calcWeights(ups), p.randSrc)
	var errs []error
	for i, ok := w.Take(); ok; i, ok = w.Take() {
		u = ups[i]

		var elapsed time.Duration
		resp, elapsed, err = p.exchange(u, req, p.time)
		if err == nil {
			p.updateRTT(u.Address(), elapsed)

			return resp, u, nil
		}

		errs = append(errs, err)

		// TODO(e.burkov):  Use the actual configured timeout or, perhaps, the
		// actual measured elapsed time.
		p.updateRTT(u.Address(), defaultTimeout)
	}

	err = fmt.Errorf("all upstreams failed to exchange request: %w", errors.Join(errs...))

	return nil, nil, err
}

// exchange sends req to the upstream u and returns its response, the duration
// of the exchange as measured by the clock c, and the exchange error, if any.
// Failed exchanges are logged at the error level and successful ones at the
// debug level.
func (p *Proxy) exchange(
	u upstream.Upstream,
	req *dns.Msg,
	c clock,
) (resp *dns.Msg, dur time.Duration, err error) {
	start := c.Now()
	resp, err = u.Exchange(req)

	// Don't use [time.Since] because it uses [time.Now].
	dur = c.Now().Sub(start)

	addr := u.Address()
	q := &req.Question[0]
	if err == nil {
		p.logger.Debug(
			"exchange successfully finished",
			"upstream", addr,
			"question", q,
			"duration", dur,
		)

		return resp, dur, nil
	}

	p.logger.Error(
		"exchange failed",
		"upstream", addr,
		"question", q,
		"duration", dur,
		slogutil.KeyError, err,
	)

	return resp, dur, err
}

// upstreamRTTStats accumulates the round-trip time measurements of a single
// upstream.  The average round-trip time is rttSum divided by reqNum.
type upstreamRTTStats struct {
	// rttSum is the total of all the measured round-trip times, in
	// microseconds.  A float64 represents integers exactly up to 2^53, which
	// is roughly 285 years' worth of microseconds.
	rttSum float64

	// reqNum is the number of measured requests.  It's a float64 so that it
	// can be combined with rttSum without type conversions.
	reqNum float64
}

// update returns a copy of stats with rtt, truncated to whole microseconds,
// added to the sum and with the request count incremented.  The receiver
// itself is not modified.
func (stats upstreamRTTStats) update(rtt time.Duration) (updated upstreamRTTStats) {
	updated = stats
	updated.rttSum += float64(rtt.Microseconds())
	updated.reqNum++

	return updated
}

// calcWeights returns the selection weights for ups, where weights[i] is the
// weight of ups[i].  The weight of an upstream is the reciprocal of its
// average round-trip time, so faster upstreams get larger weights.  Upstreams
// with no recorded requests or a zero total round-trip time get the default
// weight of 1.
func (p *Proxy) calcWeights(ups []upstream.Upstream) (weights []float64) {
	p.rttLock.Lock()
	defer p.rttLock.Unlock()

	weights = make([]float64, len(ups))
	for i, u := range ups {
		s := p.upstreamRTTStats[u.Address()]
		if s.rttSum != 0 && s.reqNum != 0 {
			weights[i] = 1 / (s.rttSum / s.reqNum)
		} else {
			// Use 1 as the default weight.
			weights[i] = 1
		}
	}

	return weights
}

// updateRTT records rtt as one more round-trip time measurement of the
// upstream with the given address.  It's safe for concurrent use, since it
// holds p.rttLock.
func (p *Proxy) updateRTT(address string, rtt time.Duration) {
	p.rttLock.Lock()
	defer p.rttLock.Unlock()

	// Don't rely on the map having been initialized, e.g. by New.
	if p.upstreamRTTStats == nil {
		p.upstreamRTTStats = make(map[string]upstreamRTTStats)
	}

	prev := p.upstreamRTTStats[address]
	p.upstreamRTTStats[address] = prev.update(rtt)
}
07070100000061000081A4000000000000000000000001679A649F00001A0F000000000000000000000000000000000000003000000000dnsproxy-0.75.0/proxy/exchange_internal_test.gopackage proxy

import (
	"net"
	"net/netip"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"golang.org/x/exp/rand"
)

// fakeClock is a [clock] for tests whose Now method delegates to onNow, so
// that tests fully control the observed time.
type fakeClock struct {
	onNow func() (now time.Time)
}

// type check
var _ clock = (*fakeClock)(nil)

// Now implements the [clock] interface for *fakeClock.
func (c *fakeClock) Now() (now time.Time) {
	now = c.onNow()

	return now
}

// newUpstreamWithErrorRate returns a fake [upstream.Upstream] with the address
// name that fails every rate-th exchange with [assert.AnError] and replies
// successfully otherwise.  rate must not be zero.  The returned upstream isn't
// safe for concurrent use.
func newUpstreamWithErrorRate(rate uint, name string) (u upstream.Upstream) {
	var reqCount uint
	exchange := func(req *dns.Msg) (resp *dns.Msg, err error) {
		reqCount++
		if reqCount%rate != 0 {
			return (&dns.Msg{}).SetReply(req), nil
		}

		return nil, assert.AnError
	}

	return &fakeUpstream{
		onExchange: exchange,
		onAddress:  func() (addr string) { return name },
		onClose:    func() (_ error) { panic("not implemented") },
	}
}

// measuredUpstream wraps an [upstream.Upstream] and counts the calls to its
// Exchange method per upstream address.
type measuredUpstream struct {
	// Upstream is the wrapped upstream.  It's embedded so that only Exchange
	// needs to be overridden.
	upstream.Upstream

	// stats maps upstream addresses to the number of exchanges made.  Several
	// measuredUpstreams may share the same map.
	stats map[string]int64
}

// type check
var _ upstream.Upstream = measuredUpstream{}

// Exchange implements the [upstream.Upstream] interface for measuredUpstream.
// It increments the counter for the wrapped upstream's address and then
// delegates to it.
func (u measuredUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	addr := u.Address()
	u.stats[addr]++

	return u.Upstream.Exchange(req)
}

// TestProxy_Exchange_loadBalance checks how requests are distributed among the
// upstreams in the load-balancing mode.  The expected counts are exact because
// the source of randomness is seeded.
func TestProxy_Exchange_loadBalance(t *testing.T) {
	// Make the test deterministic.  Note that randSrc is shared by all the
	// test cases, so the expected statistics also depend on the order in which
	// the cases run.
	randSrc := rand.NewSource(42)

	const (
		testRTT     = 1 * time.Second
		requestsNum = 10_000
	)

	// zeroingClock returns the value of currentNow and sets it back to
	// zeroTime, so that all the calls since the second one return the same zero
	// value until currentNow is modified elsewhere.
	zeroTime := time.Unix(0, 0)
	currentNow := zeroTime
	zeroingClock := &fakeClock{
		onNow: func() (now time.Time) {
			now, currentNow = currentNow, zeroTime

			return now
		},
	}
	// constClock advances currentNow by testRTT/50 on every call, so that
	// every exchange appears to take testRTT/50 regardless of the upstream.
	constClock := &fakeClock{
		onNow: func() (now time.Time) {
			now, currentNow = currentNow, currentNow.Add(testRTT/50)

			return now
		},
	}

	// The upstreams below simulate their round-trip times by setting
	// currentNow, which zeroingClock then reports as the end of the exchange.
	fastUps := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			currentNow = zeroTime.Add(testRTT / 100)

			return (&dns.Msg{}).SetReply(req), nil
		},
		onAddress: func() (addr string) { return "fast" },
		onClose:   func() (_ error) { panic("not implemented") },
	}
	slowerUps := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			currentNow = zeroTime.Add(testRTT / 10)

			return (&dns.Msg{}).SetReply(req), nil
		},
		onAddress: func() (addr string) { return "slower" },
		onClose:   func() (_ error) { panic("not implemented") },
	}
	slowestUps := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			currentNow = zeroTime.Add(testRTT / 2)

			return (&dns.Msg{}).SetReply(req), nil
		},
		onAddress: func() (addr string) { return "slowest" },
		onClose:   func() (_ error) { panic("not implemented") },
	}

	// err1Ups and err2Ups always fail.
	err1Ups := &fakeUpstream{
		onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError },
		onAddress:  func() (addr string) { return "error1" },
		onClose:    func() (_ error) { panic("not implemented") },
	}
	err2Ups := &fakeUpstream{
		onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError },
		onAddress:  func() (addr string) { return "error2" },
		onClose:    func() (_ error) { panic("not implemented") },
	}

	singleError := &sync.Once{}
	// fastestUps responds with an error on the first request.
	fastestUps := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			singleError.Do(func() { err = assert.AnError })
			currentNow = zeroTime.Add(testRTT / 200)

			return (&dns.Msg{}).SetReply(req), err
		},
		onAddress: func() (addr string) { return "fastest" },
		onClose:   func() (_ error) { panic("not implemented") },
	}

	// These fail every 200th, 100th, and 50th request respectively.
	each200 := newUpstreamWithErrorRate(200, "each_200")
	each100 := newUpstreamWithErrorRate(100, "each_100")
	each50 := newUpstreamWithErrorRate(50, "each_50")

	testCases := []struct {
		wantStat map[string]int64
		clock    clock
		name     string
		servers  []upstream.Upstream
	}{{
		wantStat: map[string]int64{
			fastUps.Address():    8917,
			slowerUps.Address():  911,
			slowestUps.Address(): 172,
		},
		clock:   zeroingClock,
		name:    "all_good",
		servers: []upstream.Upstream{slowestUps, slowerUps, fastUps},
	}, {
		wantStat: map[string]int64{
			fastUps.Address():   9081,
			slowerUps.Address(): 919,
			err1Ups.Address():   7,
		},
		clock:   zeroingClock,
		name:    "one_bad",
		servers: []upstream.Upstream{fastUps, err1Ups, slowerUps},
	}, {
		// Every request is tried on each of the failing upstreams once.
		wantStat: map[string]int64{
			err1Ups.Address(): requestsNum,
			err2Ups.Address(): requestsNum,
		},
		clock:   zeroingClock,
		name:    "all_bad",
		servers: []upstream.Upstream{err2Ups, err1Ups},
	}, {
		wantStat: map[string]int64{
			fastUps.Address():    7803,
			slowerUps.Address():  833,
			fastestUps.Address(): 1365,
		},
		clock:   zeroingClock,
		name:    "error_once",
		servers: []upstream.Upstream{fastUps, slowerUps, fastestUps},
	}, {
		wantStat: map[string]int64{
			each200.Address(): 5316,
			each100.Address(): 3090,
			each50.Address():  1683,
		},
		clock:   constClock,
		name:    "error_each_nth",
		servers: []upstream.Upstream{each200, each100, each50},
	}}

	req := newTestMessage()
	cli := netip.AddrPortFrom(netutil.IPv4Localhost(), 1234)

	for _, tc := range testCases {
		// Wrap every server so that all of them record their usage in the
		// same stats map.
		ups := []upstream.Upstream{}
		stats := map[string]int64{}
		for _, s := range tc.servers {
			ups = append(ups, measuredUpstream{
				Upstream: s,
				stats:    stats,
			})
		}

		p := mustNew(t, &Config{
			Logger:        slogutil.NewDiscardLogger(),
			UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
			TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
			UpstreamConfig: &UpstreamConfig{
				Upstreams: ups,
			},
			TrustedProxies:         defaultTrustedProxies,
			RatelimitSubnetLenIPv4: 24,
			RatelimitSubnetLenIPv6: 64,
		})
		// Replace the clock and the randomness source to make the measured
		// round-trip times and the weighted selection deterministic.
		p.time = tc.clock
		p.randSrc = randSrc

		wantStat := tc.wantStat

		t.Run(tc.name, func(t *testing.T) {
			for range requestsNum {
				_ = p.Resolve(&DNSContext{Req: req, Addr: cli})
			}

			assert.Equal(t, wantStat, stats)
		})
	}
}
07070100000062000081A4000000000000000000000001679A649F000007D8000000000000000000000000000000000000002600000000dnsproxy-0.75.0/proxy/handler_test.gopackage proxy

import (
	"context"
	"net"
	"sync"
	"testing"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestFilteringHandler(t *testing.T) {
	// Initializing the test middleware
	m := &sync.RWMutex{}
	blockResponse := false

	// Prepare the proxy server
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		RequestHandler: func(p *Proxy, d *DNSContext) error {
			m.Lock()
			defer m.Unlock()

			if !blockResponse {
				// Use the default Resolve method if response is not blocked
				return p.Resolve(d)
			}

			resp := dns.Msg{}
			resp.SetRcode(d.Req, dns.RcodeNotImplemented)
			resp.RecursionAvailable = true

			// Set the response right away
			d.Res = &resp
			return nil
		},
	})

	// Start listening
	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-UDP client connection
	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}

	// Send the first message (not blocked)
	req := newTestMessage()

	r, _, err := client.Exchange(req, addr.String())
	require.NoError(t, err)
	requireResponse(t, req, r)

	// Now send the second and make sure it is blocked
	m.Lock()
	blockResponse = true
	m.Unlock()

	r, _, err = client.Exchange(req, addr.String())
	require.NoError(t, err)
	assert.Equal(t, dns.RcodeNotImplemented, r.Rcode)
}
07070100000063000081A4000000000000000000000001679A649F00000907000000000000000000000000000000000000002100000000dnsproxy-0.75.0/proxy/helpers.gopackage proxy

import (
	"net"

	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

// ecsFromMsg returns the subnet and the scope prefix length from the first
// EDNS Client Subnet option of m that has the IPv4 (1) or IPv6 (2) address
// family.  It returns nil and 0 if m has no OPT record or no such option.
func ecsFromMsg(m *dns.Msg) (subnet *net.IPNet, scope int) {
	opt := m.IsEdns0()
	if opt == nil {
		return nil, 0
	}

	for _, o := range opt.Option {
		sn, ok := o.(*dns.EDNS0_SUBNET)
		if !ok {
			continue
		}

		var (
			ip   net.IP
			bits int
		)
		switch sn.Family {
		case 1:
			ip, bits = sn.Address.To4(), netutil.IPv4BitLen
		case 2:
			ip, bits = sn.Address, netutil.IPv6BitLen
		default:
			// Skip options with unknown address families.
			continue
		}

		mask := net.CIDRMask(int(sn.SourceNetmask), bits)

		return &net.IPNet{IP: ip, Mask: mask}, int(sn.SourceScope)
	}

	return nil, 0
}

// setECS adds an EDNS Client Subnet option for ip with the given scope to m
// and returns the subnet put into it.  ip is masked to /24 for IPv4 and /56
// for IPv6 addresses.  The option is added to the existing OPT record of m if
// there is one; otherwise a new OPT record with the UDP size of 4096 is
// appended to m.Extra.
func setECS(m *dns.Msg, ip net.IP, scope uint8) (subnet *net.IPNet) {
	const (
		// defaultECSv4 is the default length of network mask for IPv4 address
		// in ECS option.
		defaultECSv4 = 24

		// defaultECSv6 is the default length of network mask for IPv6 address
		// in ECS.  The size of 7 octets is chosen as a reasonable minimum since
		// at least Google's public DNS refuses requests containing the options
		// with longer network masks.
		defaultECSv6 = 56
	)

	// Assume the IP address has already been validated, so anything that
	// isn't IPv4 is IPv6.
	family, ones, bits := uint16(2), defaultECSv6, netutil.IPv6BitLen
	if ip4 := ip.To4(); ip4 != nil {
		family, ones, bits = 1, defaultECSv4, netutil.IPv4BitLen
		ip = ip4
	}

	mask := net.CIDRMask(ones, bits)
	subnet = &net.IPNet{IP: ip.Mask(mask), Mask: mask}

	e := &dns.EDNS0_SUBNET{
		Code:          dns.EDNS0SUBNET,
		Family:        family,
		SourceNetmask: uint8(ones),
		SourceScope:   scope,
		Address:       subnet.IP,
	}

	// Reuse the existing OPT record if there is one, since servers may return
	// FORMERR if they meet several OPT RRs.
	opt := m.IsEdns0()
	if opt == nil {
		opt = &dns.OPT{
			Hdr: dns.RR_Header{
				Name:   ".",
				Rrtype: dns.TypeOPT,
			},
		}
		opt.SetUDPSize(4096)
		m.Extra = append(m.Extra, opt)
	}
	opt.Option = append(opt.Option, e)

	return subnet
}
07070100000064000081A4000000000000000000000001679A649F0000095C000000000000000000000000000000000000002000000000dnsproxy-0.75.0/proxy/lookup.gopackage proxy

import (
	"context"
	"net/netip"
	"slices"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

// lookupResult is the result of a single query made by [Proxy.lookupIPAddr].
type lookupResult struct {
	// resp is the DNS response.  LookupNetIP only reads it when err is nil.
	resp *dns.Msg

	// err is the error returned by resolving, if any.
	err error
}

// lookupIPAddr resolves the records of type qtype for host and sends the
// result to ch.  It is intended to be used as a goroutine.  Exactly one result
// is always sent to ch, even if resolving panics, so that the receiver never
// blocks forever.
func (p *Proxy) lookupIPAddr(
	ctx context.Context,
	host string,
	qtype uint16,
	ch chan *lookupResult,
) {
	// res is replaced with the actual result once Resolve returns.  If Resolve
	// panics, the panic is recovered and this error is sent instead.
	res := &lookupResult{err: errors.Error("resolving: recovered from panic")}

	// Deferred calls run in reverse order, so a panic is recovered and logged
	// first, and only then is the result sent.
	defer func() { ch <- res }()
	defer slogutil.RecoverAndLog(ctx, p.logger)

	req := (&dns.Msg{}).SetQuestion(host, qtype)

	// TODO(d.kolyshev): Investigate why the client address is not defined.
	d := p.newDNSContext(ProtoUDP, req, netip.AddrPort{})
	err := p.Resolve(d)
	res = &lookupResult{
		resp: d.Res,
		err:  err,
	}
}

// ErrEmptyHost is returned by [Proxy.LookupNetIP] when the host is empty and
// can't be resolved.
const ErrEmptyHost = errors.Error("host is empty")

// type check
var _ upstream.Resolver = (*Proxy)(nil)

// LookupNetIP implements the [upstream.Resolver] interface for *Proxy.  The
// network argument is ignored.  It sends the A and AAAA queries for host
// concurrently and merges the addresses from both answers.  An error is
// returned only if no addresses were found and at least one of the queries
// failed.  The result is sorted stably with netutil.PreferIPv6 or
// netutil.PreferIPv4, depending on Config.PreferIPv6.
func (p *Proxy) LookupNetIP(
	ctx context.Context,
	_ string,
	host string,
) (addrs []netip.Addr, err error) {
	if host == "" {
		return nil, ErrEmptyHost
	}

	fqdn := dns.Fqdn(host)
	qtypes := [...]uint16{dns.TypeA, dns.TypeAAAA}

	ch := make(chan *lookupResult)
	for _, qt := range qtypes {
		go p.lookupIPAddr(ctx, fqdn, qt, ch)
	}

	// Receive exactly one result per started lookup.
	var errs []error
	for range qtypes {
		res := <-ch
		if res.err == nil {
			addrs = appendAnswerAddrs(addrs, res.resp.Answer)
		} else {
			errs = append(errs, res.err)
		}
	}

	if len(addrs) == 0 && len(errs) > 0 {
		return addrs, errors.Join(errs...)
	}

	cmpAddrs := netutil.PreferIPv4
	if p.Config.PreferIPv6 {
		cmpAddrs = netutil.PreferIPv6
	}
	slices.SortStableFunc(addrs, cmpAddrs)

	return addrs, nil
}

// appendAnswerAddrs returns addrs with every valid IP address that
// [proxyutil.IPFromRR] extracts from the records in ans appended to it.
// Records that don't yield an address are skipped.
func appendAnswerAddrs(addrs []netip.Addr, ans []dns.RR) (res []netip.Addr) {
	res = addrs
	for _, rr := range ans {
		if ip := proxyutil.IPFromRR(rr); ip.IsValid() {
			res = append(res, ip)
		}
	}

	return res
}
07070100000065000081A4000000000000000000000001679A649F0000046F000000000000000000000000000000000000002500000000dnsproxy-0.75.0/proxy/lookup_test.gopackage proxy

import (
	"context"
	"net/netip"
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestLookupNetIP checks that [Proxy.LookupNetIP] resolves both the IPv4 and,
// if present, the IPv6 addresses of a well-known host.  It requires network
// access.
func TestLookupNetIP(t *testing.T) {
	// Use AdGuard DNS here.
	ups, err := upstream.AddressToUpstream("94.140.14.14", &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: defaultTimeout,
	})
	require.NoError(t, err)

	p, err := New(&Config{
		Logger: slogutil.NewDiscardLogger(),
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
	})
	require.NoError(t, err)

	addrs, err := p.LookupNetIP(context.Background(), "", "dns.google")
	require.NoError(t, err)
	require.NotEmpty(t, addrs)

	for _, want := range []string{"8.8.8.8", "8.8.4.4"} {
		assert.Contains(t, addrs, netip.MustParseAddr(want))
	}

	// Only check the IPv6 addresses when the answer contains more than just
	// the IPv4 ones.
	if len(addrs) <= 2 {
		return
	}

	for _, want := range []string{"2001:4860:4860::8888", "2001:4860:4860::8844"} {
		assert.Contains(t, addrs, netip.MustParseAddr(want))
	}
}
07070100000066000081A4000000000000000000000001679A649F00000704000000000000000000000000000000000000002C00000000dnsproxy-0.75.0/proxy/optimisticresolver.gopackage proxy

import (
	"context"
	"encoding/hex"
	"log/slog"
	"sync"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
)

// cachingResolver resolves DNS requests and is able to store the responses in
// a cache.
type cachingResolver interface {
	// replyFromUpstream resolves the request from dctx.  ok is true if the
	// request was resolved and its response may be cached.
	//
	// TODO(e.burkov):  Find out when ok can be false with nil err.
	replyFromUpstream(dctx *DNSContext) (ok bool, err error)

	// cacheResp stores the response from dctx in the cache.
	cacheResp(dctx *DNSContext)
}

// type check
var _ cachingResolver = (*Proxy)(nil)

// optimisticResolver re-resolves expired cached requests, making sure that
// requests with the same key aren't resolved at the same time.
type optimisticResolver struct {
	// reqs is the set of hex-encoded keys of the requests currently being
	// resolved.
	reqs *sync.Map

	// cr resolves the requests and caches the responses.
	cr cachingResolver
}

// newOptimisticResolver returns a new optimisticResolver that uses cr to
// resolve and cache requests.  cr must not be nil.
func newOptimisticResolver(cr cachingResolver) (s *optimisticResolver) {
	s = &optimisticResolver{
		reqs: new(sync.Map),
		cr:   cr,
	}

	return s
}

// unit is a convenient alias for struct{}.
type unit = struct{}

// resolveOnce resolves the request from dctx and caches the response, unless
// a request with the same key is already being resolved, in which case it
// returns right away.  Resolving errors are only logged at the debug level;
// the response is cached whenever replyFromUpstream reports ok, even along
// with an error.  Panics are recovered and logged with l.  It is intended to
// be run in a separate goroutine.  Don't pass a *DNSContext that is used
// elsewhere, since it isn't intended to be used concurrently.
func (s *optimisticResolver) resolveOnce(dctx *DNSContext, key []byte, l *slog.Logger) {
	defer slogutil.RecoverAndLog(context.TODO(), l)

	k := hex.EncodeToString(key)
	if _, inFlight := s.reqs.LoadOrStore(k, unit{}); inFlight {
		// Another goroutine is already resolving the same request.
		return
	}
	defer s.reqs.Delete(k)

	cacheable, err := s.cr.replyFromUpstream(dctx)
	if err != nil {
		l.Debug("resolving request for optimistic cache", slogutil.KeyError, err)
	}

	if !cacheable {
		return
	}

	s.cr.cacheResp(dctx)
}
07070100000067000081A4000000000000000000000001679A649F00000C0A000000000000000000000000000000000000003100000000dnsproxy-0.75.0/proxy/optimisticresolver_test.gopackage proxy

import (
	"bytes"
	"log/slog"
	"sync"
	"testing"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/stretchr/testify/assert"
)

// testCachingResolver is a cachingResolver for tests whose methods delegate to
// the corresponding function fields.
type testCachingResolver struct {
	onReplyFromUpstream func(dctx *DNSContext) (ok bool, err error)
	onCacheResp         func(dctx *DNSContext)
}

// replyFromUpstream implements the cachingResolver interface for
// *testCachingResolver.  It calls onReplyFromUpstream.
func (tcr *testCachingResolver) replyFromUpstream(dctx *DNSContext) (ok bool, err error) {
	ok, err = tcr.onReplyFromUpstream(dctx)

	return ok, err
}

// cacheResp implements the cachingResolver interface for *testCachingResolver.
// It calls onCacheResp.
func (tcr *testCachingResolver) cacheResp(dctx *DNSContext) {
	f := tcr.onCacheResp
	f(dctx)
}

// TestOptimisticResolver_ResolveOnce checks that concurrent calls with the
// same key resolve and cache the request only once while the first call is
// still in progress.
func TestOptimisticResolver_ResolveOnce(t *testing.T) {
	// in and out synchronize the primary goroutine, which is blocked inside
	// onCacheResp, with the test goroutine.
	in, out := make(chan unit), make(chan unit)
	var timesResolved, timesSet int

	tcr := &testCachingResolver{
		onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) {
			timesResolved++

			return true, nil
		},
		onCacheResp: func(_ *DNSContext) {
			timesSet++

			// Pass the signal to begin running secondary goroutines.
			out <- unit{}
			// Block until all the secondary goroutines finish.
			<-in
		},
	}

	s := newOptimisticResolver(tcr)
	sameKey := []byte{1, 2, 3}

	// Start the primary goroutine.
	go s.resolveOnce(nil, sameKey, slogutil.NewDiscardLogger())
	// Block until the primary goroutine reaches the resolve function.
	<-out

	wg := &sync.WaitGroup{}

	// While the primary goroutine holds the key, each of these calls is
	// expected to return without resolving anything.
	const secondaryNum = 10
	wg.Add(secondaryNum)
	for range secondaryNum {
		go func() {
			defer wg.Done()

			s.resolveOnce(nil, sameKey, slogutil.NewDiscardLogger())
		}()
	}

	// Wait until all the secondary goroutines are finished.
	wg.Wait()
	// Pass the signal to terminate the primary goroutine.
	in <- unit{}

	assert.Equal(t, 1, timesResolved)
	assert.Equal(t, 1, timesSet)
}

// TestOptimisticResolver_ResolveOnce_unsuccessful checks how resolveOnce
// handles a resolving error and a response that mustn't be cached.
func TestOptimisticResolver_ResolveOnce_unsuccessful(t *testing.T) {
	key := []byte{1, 2, 3}

	t.Run("error", func(t *testing.T) {
		// TODO(d.kolyshev): Consider adding mock handler to golibs.
		buf := &bytes.Buffer{}
		handler := slog.NewTextHandler(buf, &slog.HandlerOptions{
			AddSource:   false,
			Level:       slog.LevelDebug,
			ReplaceAttr: nil,
		})
		l := slog.New(handler)

		const rErr errors.Error = "sample resolving error"

		var cached bool
		tcr := &testCachingResolver{
			onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { return true, rErr },
			onCacheResp:         func(_ *DNSContext) { cached = true },
		}
		newOptimisticResolver(tcr).resolveOnce(nil, key, l)

		// The response is cached since ok is true, and the error is logged.
		assert.True(t, cached)
		assert.Contains(t, buf.String(), rErr.Error())
	})

	t.Run("not_ok", func(t *testing.T) {
		var cached bool
		tcr := &testCachingResolver{
			onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { return false, nil },
			onCacheResp:         func(_ *DNSContext) { cached = true },
		}
		newOptimisticResolver(tcr).resolveOnce(nil, key, slogutil.NewDiscardLogger())

		assert.False(t, cached)
	})
}
07070100000068000081A4000000000000000000000001679A649F000050F9000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/proxy.go// Package proxy implements a DNS proxy that supports all known DNS encryption
// protocols.
package proxy

import (
	"cmp"
	"context"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/netip"
	"slices"
	"sync"
	"sync/atomic"
	"time"

	"github.com/AdguardTeam/dnsproxy/fastip"
	"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/service"
	"github.com/AdguardTeam/golibs/syncutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
	gocache "github.com/patrickmn/go-cache"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"golang.org/x/exp/rand"
)

const (
	// defaultTimeout is the default timeout.  It's also used as the
	// round-trip time penalty for failed exchanges in the load-balancing mode.
	defaultTimeout = 10 * time.Second

	// minDNSPacketSize is the minimum size of a DNS packet: the 12-byte header
	// plus, presumably, the 5 bytes of the smallest possible question (the
	// root name, the type, and the class).
	minDNSPacketSize = 12 + 5
)

// Proto is the name of a DNS transport protocol.
type Proto string

// Proto values.
const (
	// ProtoUDP is the plain DNS-over-UDP protocol.
	ProtoUDP Proto = "udp"

	// ProtoTCP is the plain DNS-over-TCP protocol.
	ProtoTCP Proto = "tcp"

	// ProtoTLS is the DNS-over-TLS (DoT) protocol.
	ProtoTLS Proto = "tls"

	// ProtoHTTPS is the DNS-over-HTTPS (DoH) protocol.
	ProtoHTTPS Proto = "https"

	// ProtoQUIC is the DNS-over-QUIC (DoQ) protocol.
	ProtoQUIC Proto = "quic"

	// ProtoDNSCrypt is the DNSCrypt protocol.
	ProtoDNSCrypt Proto = "dnscrypt"
)

// Proxy combines the proxy server state and configuration.
//
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
type Proxy struct {
	// requestsSema limits the number of simultaneous requests.
	//
	// TODO(a.garipov): Currently we have to pass this exact semaphore to the
	// workers, to prevent races on restart.  In the future we will need a
	// better restarting mechanism that completely prevents such invalid states.
	//
	// See also: https://github.com/AdguardTeam/AdGuardHome/issues/2242.
	requestsSema syncutil.Semaphore

	// privateNets determines if the requested address and the client address
	// are private.
	privateNets netutil.SubnetSet

	// time provides the current time.  New sets it to realClock{}.
	//
	// TODO(e.burkov):  Consider configuring it.
	time clock

	// randSrc provides the source of randomness for the weighted random
	// selection of upstreams in the load-balancing mode.
	//
	// NOTE(review): New doesn't set it, so it stays nil unless it's set
	// elsewhere; presumably sampleuv then falls back to a default source —
	// confirm.
	//
	// TODO(e.burkov):  Consider configuring it.
	randSrc rand.Source

	// messages constructs DNS messages.
	messages MessageConstructor

	// beforeRequestHandler handles the request's context before it is resolved.
	beforeRequestHandler BeforeRequestHandler

	// dnsCryptServer serves DNSCrypt queries.
	dnsCryptServer *dnscrypt.Server

	// logger is used for logging in the proxy service.  It is never nil.
	logger *slog.Logger

	// ratelimitBuckets is a storage for ratelimiters for individual IPs.  It's
	// protected by ratelimitLock.
	ratelimitBuckets *gocache.Cache

	// fastestAddr finds the fastest IP address for the resolved domain.  It's
	// only set by New in the fastest-address upstream mode.
	fastestAddr *fastip.FastestAddr

	// cache is used to cache requests.  It is disabled if nil.
	//
	// TODO(d.kolyshev): Move this cache to [Proxy.UpstreamConfig] field.
	cache *cache

	// shortFlighter is used to resolve the expired cached requests without
	// repetitions.
	shortFlighter *optimisticResolver

	// recDetector detects recursive requests that may appear when resolving
	// requests for private addresses.
	recDetector *recursionDetector

	// bytesPool is a pool of byte slices used to read DNS packets.
	//
	// TODO(e.burkov):  Use [syncutil.Pool].
	bytesPool *sync.Pool

	// udpListen are the listened UDP connections.
	udpListen []*net.UDPConn

	// tcpListen are the listened TCP connections.
	tcpListen []net.Listener

	// tlsListen are the listened TCP connections with TLS.
	tlsListen []net.Listener

	// quicListen are the listened QUIC connections.
	quicListen []*quic.EarlyListener

	// quicConns are UDP connections for all listened QUIC connections.  These
	// should be closed on shutdown, since *quic.EarlyListener doesn't close
	// them.
	quicConns []*net.UDPConn

	// quicTransports are transports for all listened QUIC connections.  These
	// should be closed on shutdown, since *quic.EarlyListener doesn't close
	// them.
	quicTransports []*quic.Transport

	// httpsListen are the listened HTTPS connections.
	httpsListen []net.Listener

	// h3Listen are the listened HTTP/3 connections.
	h3Listen []*quic.EarlyListener

	// httpsServer serves queries received over HTTPS.
	httpsServer *http.Server

	// h3Server serves queries received over HTTP/3.
	h3Server *http3.Server

	// dnsCryptUDPListen are the listened UDP connections for DNSCrypt.
	dnsCryptUDPListen []*net.UDPConn

	// dnsCryptTCPListen are the listened TCP connections for DNSCrypt.
	dnsCryptTCPListen []net.Listener

	// upstreamRTTStats maps the upstream address to its round-trip time
	// statistics.  It holds the statistics for all upstreams to perform a
	// weighted random selection when using the load balancing mode.  It's
	// protected by rttLock.
	upstreamRTTStats map[string]upstreamRTTStats

	// dns64Prefs is a set of NAT64 prefixes that are used to detect and
	// construct DNS64 responses.  The DNS64 function is disabled if it is
	// empty.
	dns64Prefs netutil.SliceSubnetSet

	// Config is the proxy configuration.  New stores a shallow copy of the
	// configuration passed to it.
	//
	// TODO(a.garipov): Remove this embed and create a proper initializer.
	Config

	// udpOOBSize is the size of the out-of-band data for UDP connections.
	udpOOBSize int

	// counter counts message contexts created with [Proxy.newDNSContext].
	counter atomic.Uint64

	// RWMutex protects the whole proxy.  Start and Shutdown hold it for
	// writing; isStarted holds it for reading.
	//
	// TODO(e.burkov):  Find out what exactly it protects and name it properly.
	// Also make it a pointer.
	sync.RWMutex

	// ratelimitLock protects ratelimitBuckets.
	ratelimitLock sync.Mutex

	// rttLock protects upstreamRTTStats.
	//
	// TODO(e.burkov):  Make it a pointer.
	rttLock sync.Mutex

	// started indicates if the proxy has been started.
	started bool
}

// New creates a new Proxy with the specified configuration.  c must not be nil.
// The configuration is copied shallowly into the returned Proxy.  New validates
// the configuration and the basic-auth settings, initializes the cache, the
// request semaphore, the upstream mode, and DNS64, and returns an error if any
// of these steps fail.
//
// TODO(e.burkov):  Cover with tests.
func New(c *Config) (p *Proxy, err error) {
	p = &Proxy{
		Config: *c,
		// Fall back to the locally-served networks if no private subnets are
		// configured.
		privateNets: cmp.Or[netutil.SubnetSet](
			c.PrivateSubnets,
			netutil.SubnetSetFunc(netutil.IsLocallyServed),
		),
		beforeRequestHandler: cmp.Or[BeforeRequestHandler](
			c.BeforeRequestHandler,
			noopRequestHandler{},
		),
		upstreamRTTStats: map[string]upstreamRTTStats{},
		rttLock:          sync.Mutex{},
		ratelimitLock:    sync.Mutex{},
		RWMutex:          sync.RWMutex{},
		bytesPool: &sync.Pool{
			New: func() any {
				// 2 bytes may be used to store packet length (see TCP/TLS).
				b := make([]byte, 2+dns.MaxMsgSize)

				return &b
			},
		},
		udpOOBSize: proxynetutil.UDPGetOOBSize(),
		time:       realClock{},
		messages: cmp.Or[MessageConstructor](
			c.MessageConstructor,
			dnsmsg.DefaultMessageConstructor{},
		),
		recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
	}

	// Make sure p.logger is never nil.
	if c.Logger != nil {
		p.logger = c.Logger
	} else {
		p.logger = slog.Default().With(slogutil.KeyPrefix, LogPrefix)
	}

	// TODO(e.burkov):  Validate config separately and add the contract to the
	// New function.
	err = p.validateConfig()
	if err != nil {
		return nil, err
	}

	// TODO(s.chzhen):  Consider moving to [Proxy.validateConfig].
	err = p.validateBasicAuth()
	if err != nil {
		return nil, fmt.Errorf("basic auth: %w", err)
	}

	p.initCache()

	// Limit the number of simultaneously processed requests only if it's
	// configured; otherwise use syncutil.EmptySemaphore, presumably imposing
	// no limit.
	if p.MaxGoroutines > 0 {
		p.logger.Info("max goroutines is set", "count", p.MaxGoroutines)

		p.requestsSema = syncutil.NewChanSemaphore(p.MaxGoroutines)
	} else {
		p.requestsSema = syncutil.EmptySemaphore{}
	}

	// An empty upstream mode means load balancing.  The fastest-address mode
	// additionally needs the fastip service.
	//
	// NOTE(review): this passes p.Logger, the configured and possibly nil
	// logger, rather than p.logger, which is never nil; confirm that fastip
	// handles a nil Logger.
	if p.UpstreamMode == "" {
		p.UpstreamMode = UpstreamModeLoadBalance
	} else if p.UpstreamMode == UpstreamModeFastestAddr {
		p.fastestAddr = fastip.New(&fastip.Config{
			Logger:          p.Logger,
			PingWaitTimeout: p.FastestPingTimeout,
		})
	}

	err = p.setupDNS64()
	if err != nil {
		return nil, fmt.Errorf("setting up DNS64: %w", err)
	}

	// Clone the whitelist before sorting it, since the shallow copy of the
	// configuration shares the slice with the caller's c.
	p.RatelimitWhitelist = slices.Clone(p.RatelimitWhitelist)
	slices.SortFunc(p.RatelimitWhitelist, netip.Addr.Compare)

	return p, nil
}

// validateBasicAuth validates the basic-auth settings.  If Userinfo is set, at
// least one HTTPS listen address must be configured.
func (p *Proxy) validateBasicAuth() (err error) {
	if p.Config.Userinfo != nil && len(p.Config.HTTPSListenAddr) == 0 {
		return errors.Error("no https addrs")
	}

	return nil
}

// isStarted reports whether the proxy has been started.  It is safe for
// concurrent use.
func (p *Proxy) isStarted() (ok bool) {
	p.RLock()
	ok = p.started
	p.RUnlock()

	return ok
}

// type check
var _ service.Interface = (*Proxy)(nil)

// Start implements the [service.Interface] for *Proxy.  It validates the
// listen addresses, configures and starts all the listeners, and marks the
// proxy as started.  Starting an already started proxy is an error.
func (p *Proxy) Start(ctx context.Context) (err error) {
	p.logger.InfoContext(ctx, "starting dns proxy server")

	p.Lock()
	defer p.Unlock()

	if p.started {
		return errors.Error("server has been already started")
	}

	// The validation error is informative enough as is, so don't wrap it.
	if err = p.validateListenAddrs(); err != nil {
		return err
	}

	if err = p.configureListeners(ctx); err != nil {
		return fmt.Errorf("configuring listeners: %w", err)
	}

	p.startListeners()
	p.started = true

	return nil
}

// closeAll calls Close on every closer, in order, and returns errs with the
// non-nil close errors appended.  All closers are closed even if some of them
// fail.
func closeAll[C io.Closer](errs []error, closers ...C) (appended []error) {
	appended = errs
	for i := range closers {
		if err := closers[i].Close(); err != nil {
			appended = append(appended, err)
		}
	}

	return appended
}

// Shutdown implements the [service.Interface] for *Proxy.  It closes all
// listeners, the HTTPS and HTTP/3 servers, the QUIC transports and
// connections, the DNSCrypt listeners, and the upstream configurations (main,
// private RDNS, and fallback), then marks the proxy as not started.  Closing
// continues after individual failures; all close errors are returned joined.
// Calling Shutdown on a proxy that isn't started only logs a warning and
// returns nil.
func (p *Proxy) Shutdown(ctx context.Context) (err error) {
	p.logger.InfoContext(ctx, "stopping server")

	p.Lock()
	defer p.Unlock()

	if !p.started {
		// TODO(a.garipov): Consider returning err.
		p.logger.WarnContext(ctx, "dns proxy server is not started")

		return nil
	}

	errs := closeAll(nil, p.tcpListen...)
	p.tcpListen = nil

	errs = closeAll(errs, p.udpListen...)
	p.udpListen = nil

	errs = closeAll(errs, p.tlsListen...)
	p.tlsListen = nil

	if p.httpsServer != nil {
		errs = closeAll(errs, p.httpsServer)
		p.httpsServer = nil

		// No need to close these since they're closed by httpsServer.Close().
		p.httpsListen = nil
	}

	if p.h3Server != nil {
		errs = closeAll(errs, p.h3Server)
		p.h3Server = nil
	}

	errs = closeAll(errs, p.h3Listen...)
	p.h3Listen = nil

	errs = closeAll(errs, p.quicListen...)
	p.quicListen = nil

	errs = closeAll(errs, p.quicTransports...)
	p.quicTransports = nil

	errs = closeAll(errs, p.quicConns...)
	p.quicConns = nil

	errs = closeAll(errs, p.dnsCryptUDPListen...)
	p.dnsCryptUDPListen = nil

	errs = closeAll(errs, p.dnsCryptTCPListen...)
	p.dnsCryptTCPListen = nil

	// The upstream configurations are closed as well.  Nil ones are skipped,
	// since the private RDNS and fallback configurations are optional.
	for _, u := range []*UpstreamConfig{
		p.UpstreamConfig,
		p.PrivateRDNSUpstreamConfig,
		p.Fallbacks,
	} {
		if u != nil {
			errs = closeAll(errs, u)
		}
	}

	p.started = false

	p.logger.InfoContext(ctx, "stopped dns proxy server")

	if len(errs) > 0 {
		return fmt.Errorf("stopping dns proxy server: %w", errors.Join(errs...))
	}

	return nil
}

// addrFunc provides the address from the given A.  It's used to extract the
// listening address from different listener types, e.g. net.Listener.Addr or
// (*net.UDPConn).LocalAddr.
type addrFunc[A any] func(l A) (addr net.Addr)

// collectAddrs applies af to every listener and returns the resulting network
// addresses in the same order.  It returns nil when listeners is empty.
func collectAddrs[A any](listeners []A, af addrFunc[A]) (addrs []net.Addr) {
	if len(listeners) == 0 {
		return nil
	}

	addrs = make([]net.Addr, len(listeners))
	for i := range listeners {
		addrs[i] = af(listeners[i])
	}

	return addrs
}

// Addrs returns all listen addresses for the specified proto or nil if the
// proxy does not listen to it.  proto must be one of [Proto]: [ProtoTCP],
// [ProtoUDP], [ProtoTLS], [ProtoHTTPS], [ProtoQUIC], or [ProtoDNSCrypt]; any
// other value causes a panic.  For [ProtoDNSCrypt] only the UDP listeners'
// addresses are returned.  It is safe for concurrent use.
func (p *Proxy) Addrs(proto Proto) (addrs []net.Addr) {
	p.RLock()
	defer p.RUnlock()

	switch proto {
	case ProtoTCP:
		return collectAddrs(p.tcpListen, net.Listener.Addr)
	case ProtoTLS:
		return collectAddrs(p.tlsListen, net.Listener.Addr)
	case ProtoHTTPS:
		return collectAddrs(p.httpsListen, net.Listener.Addr)
	case ProtoUDP:
		return collectAddrs(p.udpListen, (*net.UDPConn).LocalAddr)
	case ProtoQUIC:
		return collectAddrs(p.quicListen, (*quic.EarlyListener).Addr)
	case ProtoDNSCrypt:
		// Using only UDP addrs here.
		//
		// TODO(ameshkov): To do it better we should either do
		// ProtoDNSCryptTCP/ProtoDNSCryptUDP or we should change the
		// configuration so that it was not possible to set different ports for
		// TCP/UDP listeners.
		return collectAddrs(p.dnsCryptUDPListen, (*net.UDPConn).LocalAddr)
	default:
		panic("proto must be 'tcp', 'tls', 'https', 'quic', 'dnscrypt' or 'udp'")
	}
}

// firstAddr returns the address of listeners[0] obtained with af, or nil when
// there are no listeners at all.
func firstAddr[A any](listeners []A, af addrFunc[A]) (addr net.Addr) {
	if len(listeners) > 0 {
		addr = af(listeners[0])
	}

	return addr
}

// Addr returns the first listen address for the specified proto or nil if the
// proxy does not listen to it.  proto must be one of [Proto]: [ProtoTCP],
// [ProtoUDP], [ProtoTLS], [ProtoHTTPS], [ProtoQUIC], or [ProtoDNSCrypt]; any
// other value causes a panic.  For [ProtoDNSCrypt] the address of the first UDP
// listener is returned.  It is safe for concurrent use.
func (p *Proxy) Addr(proto Proto) (addr net.Addr) {
	p.RLock()
	defer p.RUnlock()

	switch proto {
	case ProtoTCP:
		return firstAddr(p.tcpListen, net.Listener.Addr)
	case ProtoTLS:
		return firstAddr(p.tlsListen, net.Listener.Addr)
	case ProtoHTTPS:
		return firstAddr(p.httpsListen, net.Listener.Addr)
	case ProtoUDP:
		return firstAddr(p.udpListen, (*net.UDPConn).LocalAddr)
	case ProtoQUIC:
		return firstAddr(p.quicListen, (*quic.EarlyListener).Addr)
	case ProtoDNSCrypt:
		return firstAddr(p.dnsCryptUDPListen, (*net.UDPConn).LocalAddr)
	default:
		panic("proto must be 'tcp', 'tls', 'https', 'quic', 'dnscrypt' or 'udp'")
	}
}

// selectUpstreams returns the upstreams to use for the request in d and
// whether the request is considered private.
//
// Requests with d.RequestedPrivateRDNS set, and those for which
// p.shouldStripDNS64 reports true, are private: isPrivate is true for them.
// The private upstreams from p.PrivateRDNSUpstreamConfig are returned for them
// only when p.UsePrivateRDNS is set and the client is private; otherwise no
// upstreams are returned.
//
// For other requests it firstly considers custom upstreams if those aren't
// empty and then the configured ones.  DS requests use the DS-specific
// selection.  The returned slice may be empty or nil.  d.Req must contain at
// least one question, since only the first one is inspected.
func (p *Proxy) selectUpstreams(d *DNSContext) (upstreams []upstream.Upstream, isPrivate bool) {
	q := d.Req.Question[0]
	host := q.Name

	if d.RequestedPrivateRDNS != (netip.Prefix{}) || p.shouldStripDNS64(d.Req) {
		// Use private upstreams.
		private := p.PrivateRDNSUpstreamConfig
		if p.UsePrivateRDNS && d.IsPrivateClient && private != nil {
			// This may only be a PTR, SOA, and NS request.
			upstreams = private.getUpstreamsForDomain(host)
		}

		return upstreams, true
	}

	// DS records are served by the parent zone, so they need a separate
	// upstream lookup.
	getUpstreams := (*UpstreamConfig).getUpstreamsForDomain
	if q.Qtype == dns.TypeDS {
		getUpstreams = (*UpstreamConfig).getUpstreamsForDS
	}

	if custom := d.CustomUpstreamConfig; custom != nil {
		// Try to use custom.
		upstreams = getUpstreams(custom.upstream, host)
		if len(upstreams) > 0 {
			return upstreams, false
		}
	}

	// Use configured.
	return getUpstreams(p.UpstreamConfig, host), false
}

// replyFromUpstream tries to resolve the request via configured upstream
// servers.  It returns true if the response actually came from an upstream,
// that is if the final response isn't nil, along with the last resolving
// error, if any.
//
// If no upstreams are selected for the request, d.Res is set to an NXDOMAIN
// response and an error wrapping [upstream.ErrNoUpstreams] is returned.  If
// the exchange with the selected upstreams fails for a non-private request and
// fallbacks are configured, the fallback upstreams are tried.  The final result
// and query statistics are stored in d.
func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) {
	req := d.Req

	upstreams, isPrivate := p.selectUpstreams(d)
	if len(upstreams) == 0 {
		d.Res = p.messages.NewMsgNXDOMAIN(req)

		return false, fmt.Errorf("selecting upstream: %w", upstream.ErrNoUpstreams)
	}

	if isPrivate {
		// Record the private request in p.recDetector.
		//
		// NOTE(review): presumably used to detect requests looping back
		// through the private upstreams; confirm against recDetector.
		p.recDetector.add(d.Req)
	}

	// src is only used in debug logs to tell upstreams and fallbacks apart.
	src := "upstream"
	wrapped := upstreamsWithStats(upstreams)

	// Perform the DNS request.
	resp, u, err := p.exchangeUpstreams(req, wrapped)

	// If DNS64 has been performed, report its upstream as the one that has
	// resolved the request.  Otherwise replace responses containing
	// bogus-NXDOMAIN addresses with an NXDOMAIN response.
	if dns64Ups := p.performDNS64(req, resp, wrapped); dns64Ups != nil {
		u = dns64Ups
	} else if p.isBogusNXDomain(resp) {
		p.logger.Debug("response contains bogus-nxdomain ip")
		resp = p.messages.NewMsgNXDOMAIN(req)
	}

	// Fallbacks are never used for private requests.
	var wrappedFallbacks []upstream.Upstream
	if err != nil && !isPrivate && p.Fallbacks != nil {
		p.logger.Debug("using fallback", slogutil.KeyError, err)

		src = "fallback"

		// The fallback upstreams mustn't be empty since they have been
		// validated when creating the proxy.
		upstreams = p.Fallbacks.getUpstreamsForDomain(req.Question[0].Name)

		wrappedFallbacks = upstreamsWithStats(upstreams)
		resp, u, err = upstream.ExchangeParallel(wrappedFallbacks, req)
	}

	if err != nil {
		p.logger.Debug("resolving err", "src", src, slogutil.KeyError, err)
	}

	if resp != nil {
		p.logger.Debug("resolved", "src", src)
	}

	unwrapped, stats := collectQueryStats(p.UpstreamMode, u, wrapped, wrappedFallbacks)
	d.queryStatistics = stats

	p.handleExchangeResult(d, req, resp, unwrapped)

	return resp != nil, err
}

// handleExchangeResult stores the result of an upstream exchange in d.  A nil
// resp means that no upstream has answered, so d.Res is set to a SERVFAIL
// response for req and d.hasEDNS0 is reset.  Otherwise d.Upstream is set to u,
// d.Res to resp, the TTLs of resp are adjusted, and a missing question section
// of resp is restored from req.
func (p *Proxy) handleExchangeResult(d *DNSContext, req, resp *dns.Msg, u upstream.Upstream) {
	if resp != nil {
		d.Upstream, d.Res = u, resp

		p.setMinMaxTTL(resp)

		// Some upstreams respond with invalidly constructed messages lacking the
		// question section, which cause out-of-range panics afterwards, so
		// construct it explicitly.
		//
		// See https://github.com/AdguardTeam/AdGuardHome/issues/3551.
		if len(resp.Question) == 0 && len(req.Question) > 0 {
			resp.Question = []dns.Question{req.Question[0]}
		}

		return
	}

	d.Res = p.messages.NewMsgSERVFAIL(req)
	d.hasEDNS0 = false
}

// addDO sets the DNSSEC OK bit of msg.  When msg has no OPT resource record
// yet, one is added with [defaultUDPBufSize] and the DO bit set.
func addDO(msg *dns.Msg) {
	opt := msg.IsEdns0()
	if opt == nil {
		msg.SetEdns0(defaultUDPBufSize, true)

		return
	}

	// Setting the bit is idempotent, so there's no need to check it first.
	opt.SetDo()
}

// defaultUDPBufSize defines the default size of UDP buffer for EDNS0 RRs, in
// bytes.  It's used when an OPT RR has to be added to a message, see addDO.
const defaultUDPBufSize = 2048

// Resolve is the default resolving method used by the DNS proxy to query
// upstream servers.  It expects dctx to be filled with the request and the
// client's address; the latter is used for EDNS Client Subnet when ECS is
// enabled and p.EDNSAddr is not set.  The response, if any, is stored in
// dctx.Res.
//
// Responses served from the cache are returned without calling
// p.ResponseHandler.  Otherwise p.ResponseHandler, if set, is called with dctx
// and the resolving error, which is also returned.
func (p *Proxy) Resolve(dctx *DNSContext) (err error) {
	if p.EnableEDNSClientSubnet {
		dctx.processECS(p.EDNSAddr, p.logger)
	}

	dctx.calcFlagsAndSize()

	// Look up the cache only if it works for this request, see
	// [Proxy.cacheWorks].  In particular, don't look up the cache for requests
	// with DNSSEC checking disabled since only validated responses are cached
	// and those may be not the desired result for user specifying CD flag.
	cacheWorks := p.cacheWorks(dctx)
	if cacheWorks {
		if p.replyFromCache(dctx) {
			// Complete the response from cache.
			dctx.scrub()

			return nil
		}

		// On cache miss request for DNSSEC from the upstream to cache it
		// afterwards.
		addDO(dctx.Req)
	}

	var ok bool
	ok, err = p.replyFromUpstream(dctx)

	// Don't cache the responses having CD flag, just like Dnsmasq does.  It
	// prevents the cache from being poisoned with unvalidated answers which may
	// differ from validated ones.
	//
	// See https://github.com/imp/dnsmasq/blob/770bce967cfc9967273d0acfb3ea018fb7b17522/src/forward.c#L1169-L1172.
	if cacheWorks && ok && !dctx.Res.CheckingDisabled {
		// Cache the response with DNSSEC RRs.
		p.cacheResp(dctx)
	}

	// It is possible that the response is nil if the upstream hasn't been
	// chosen.
	if dctx.Res != nil {
		// Strip the DNSSEC RRs which may have been requested by addDO but
		// weren't asked for by the client.
		//
		// NOTE(review): inferred from the adBit/doBit arguments; confirm
		// against filterMsg.
		filterMsg(dctx.Res, dctx.Res, dctx.adBit, dctx.doBit, 0)
	}

	// Complete the response.
	dctx.scrub()

	if p.ResponseHandler != nil {
		p.ResponseHandler(dctx, err)
	}

	return err
}

// cacheWorks reports whether the proxy cache may be used for the request in
// dctx.  The cache is bypassed when:
//   - it's disabled;
//   - the request is intended for private RDNS upstreams;
//   - custom upstreams are used without their own cache;
//   - the request has DNSSEC checking disabled.
//
// The reason for bypassing is logged at debug level.
func (p *Proxy) cacheWorks(dctx *DNSContext) (ok bool) {
	var reason string
	switch {
	case p.cache == nil:
		reason = "disabled"
	case dctx.RequestedPrivateRDNS != netip.Prefix{}:
		// Requests intended for local upstream servers aren't cached, those
		// should be fast enough as is.
		reason = "requested address is private"
	case dctx.CustomUpstreamConfig != nil && dctx.CustomUpstreamConfig.cache == nil:
		// Custom upstreams without their own cache can't use the global proxy
		// cache because different upstreams can return different results.
		//
		// See https://github.com/AdguardTeam/dnsproxy/issues/169.
		//
		// TODO(e.burkov):  It probably should be decided after resolve.
		reason = "custom upstreams cache is not configured"
	case dctx.Req.CheckingDisabled:
		reason = "dnssec check disabled"
	}

	if reason == "" {
		return true
	}

	p.logger.Debug("not caching", "reason", reason)

	return false
}

// processECS adds EDNS Client Subnet data into the request from dctx.
//
// If the request already carries an ECS option with a non-zero source prefix
// length, that option is passed through and stored in dctx.ReqECS.  Otherwise
// the client address is taken from cliIP or, when cliIP is nil, from dctx.Addr.
// The ECS option is only set when that address isn't a special-purpose one.
// l is used for debug logging.
func (dctx *DNSContext) processECS(cliIP net.IP, l *slog.Logger) {
	if ecs, _ := ecsFromMsg(dctx.Req); ecs != nil {
		if ones, _ := ecs.Mask.Size(); ones != 0 {
			dctx.ReqECS = ecs

			l.Debug("passing through ecs", "subnet", dctx.ReqECS)

			return
		}
	}

	var cliAddr netip.Addr
	if cliIP == nil {
		cliAddr = dctx.Addr.Addr()
		cliIP = cliAddr.AsSlice()
	} else {
		// NOTE(review): the ok result is ignored; an IP whose length is
		// neither 4 nor 16 bytes results in the zero netip.Addr here.
		cliAddr, _ = netip.AddrFromSlice(cliIP)
	}

	// IPv4 addresses stored in the 16-byte net.IP form, as well as IPv4
	// client addresses from dual-stack sockets, become IPv4-mapped IPv6
	// addresses in netip.  Those fall into the ::ffff:0:0/96 IPv6
	// special-purpose block, which would wrongly suppress ECS for public IPv4
	// clients, so unmap the address before the check.
	if !netutil.IsSpecialPurpose(cliAddr.Unmap()) {
		// A Stub Resolver MUST set SCOPE PREFIX-LENGTH to 0.  See RFC 7871
		// Section 6.
		dctx.ReqECS = setECS(dctx.Req, cliIP, 0)

		l.Debug("setting ecs", "subnet", dctx.ReqECS)
	}
}
07070100000069000081A4000000000000000000000001679A649F0000A28D000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/proxy_test.gopackage proxy

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"net"
	"net/netip"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
	"github.com/AdguardTeam/dnsproxy/upstream"
	glcache "github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Common test parameters:
//   - listenIP is the IP address the test proxies listen on;
//   - testDefaultUpstreamAddr is the plain DNS upstream used by default;
//   - tlsServerName is the server name put into the test TLS certificates;
//   - testMessagesCount is the number of requests sent by the tests that send
//     several messages in a row or in parallel.
const (
	listenIP                = "127.0.0.1"
	testDefaultUpstreamAddr = "8.8.8.8:53"
	tlsServerName           = "testdns.adguard.com"
	testMessagesCount       = 10

	// defaultTestTTL used to guarantee caching.
	defaultTestTTL = 1000

	// testTimeout is the default timeout for tests.
	testTimeout = 500 * time.Millisecond
)

// localhostAnyPort is a [netip.AddrPort] having a value of 127.0.0.1:0.  Port 0
// lets the OS choose a free port, so the tests don't conflict with each other.
var localhostAnyPort = netip.MustParseAddrPort(netutil.JoinHostPort(listenIP, 0))

// defaultTrustedProxies is a set of trusted proxies that includes all possible
// IP addresses, both IPv4 and IPv6.
var defaultTrustedProxies netutil.SubnetSet = netutil.SliceSubnetSet{
	netip.MustParsePrefix("0.0.0.0/0"),
	netip.MustParsePrefix("::0/0"),
}

// mustNew creates a new proxy from conf using [New] and fails the test if that
// returns an error.
func mustNew(t *testing.T, conf *Config) (p *Proxy) {
	t.Helper()

	var err error
	p, err = New(conf)
	require.NoError(t, err)

	return p
}

// sendTestMessages sends [testMessagesCount] DNS requests to the specified
// connection one by one and checks each response with requireResponse.
func sendTestMessages(t *testing.T, conn *dns.Conn) {
	// Report failures at the caller's line, like the other require-helpers
	// in this file do.
	t.Helper()

	for i := range testMessagesCount {
		req := newTestMessage()
		err := conn.WriteMsg(req)
		require.NoErrorf(t, err, "req number %d", i)

		res, err := conn.ReadMsg()
		require.NoErrorf(t, err, "resp number %d", i)

		requireResponse(t, req, res)
	}
}

// newTestMessage returns a new recursive A request for the well-known Google
// Public DNS host, which resolves to 8.8.8.8.
func newTestMessage() *dns.Msg {
	const testHost = "google-public-dns-a.google.com"

	return newHostTestMessage(testHost)
}

// newHostTestMessage returns a new recursive A request of class IN for host
// with a random ID.  host must not be fully-qualified, since the trailing dot
// is appended.
func newHostTestMessage(host string) (req *dns.Msg) {
	q := dns.Question{
		Name:   host + ".",
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	req = &dns.Msg{Question: []dns.Question{q}}
	req.Id = dns.Id()
	req.RecursionDesired = true

	return req
}

// requireResponse checks that reply answers req with exactly one A record
// pointing to 8.8.8.8, failing the test otherwise.
func requireResponse(t testing.TB, req, reply *dns.Msg) {
	t.Helper()

	require.NotNil(t, reply)

	answers := reply.Answer
	require.Lenf(t, answers, 1, "wrong number of answers: %d", len(answers))
	require.Equal(t, req.Id, reply.Id)

	aRR, isA := answers[0].(*dns.A)
	require.Truef(t, isA, "wrong answer type: %v", answers[0])

	got := aRR.A.To16()
	require.Equalf(t, net.IPv4(8, 8, 8, 8), got, "wrong answer: %v", aRR.A)
}

// newTLSConfig generates a self-signed RSA certificate for [tlsServerName],
// valid for about five years starting now, and returns a server TLS config
// using it along with the PEM-encoded certificate.
func newTLSConfig(t *testing.T) (conf *tls.Config, certPem []byte) {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialLimit)
	require.NoError(t, err)

	validFrom := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{Organization: []string{"AdGuard Tests"}},
		NotBefore:    validFrom,
		NotAfter:     validFrom.Add(5 * 365 * time.Hour * 24),
		KeyUsage: x509.KeyUsageKeyEncipherment |
			x509.KeyUsageDigitalSignature |
			x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
		DNSNames:              []string{tlsServerName},
	}

	// The certificate is self-signed, so it's both the template and the parent.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	require.NoError(t, err)

	certPem = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyPem := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})

	cert, err := tls.X509KeyPair(certPem, keyPem)
	require.NoError(t, err)

	conf = &tls.Config{
		Certificates: []tls.Certificate{cert},
		ServerName:   tlsServerName,
	}

	return conf, certPem
}

// firstIP returns the address from the first A record in the answer section of
// resp, or nil if there are no A records.
func firstIP(resp *dns.Msg) (ip net.IP) {
	for _, rr := range resp.Answer {
		if a, isA := rr.(*dns.A); isA {
			return a.A
		}
	}

	return nil
}

// testUpstream is an [upstream.Upstream] that answers every request locally.
// It responds with ans and, if ecsIP is set, attaches an ECS option with that
// IP.  The ECS IP and prefix length from the last request that had an ECS
// option are recorded into ecsReqIP and ecsReqMask.
type testUpstream struct {
	ans []dns.RR

	ecsIP      net.IP
	ecsReqIP   net.IP
	ecsReqMask int
}

// type check
//
// Ensure at compile time that *testUpstream implements [upstream.Upstream].
var _ upstream.Upstream = (*testUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *testUpstream.  It
// replies to m with the configured answers, records the ECS data of m, if any,
// and adds an ECS option with u.ecsIP and a /24 prefix when u.ecsIP is set.
func (u *testUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	resp = (&dns.Msg{}).SetReply(m)

	// Appending an empty or nil slice is a no-op, so no check is needed.
	resp.Answer = append(resp.Answer, u.ans...)

	if reqECS, _ := ecsFromMsg(m); reqECS != nil {
		u.ecsReqIP = reqECS.IP
		u.ecsReqMask, _ = reqECS.Mask.Size()
	}

	if u.ecsIP != nil {
		setECS(resp, u.ecsIP, 24)
	}

	return resp, nil
}

// Address implements the upstream.Upstream interface for *testUpstream.  The
// test upstream has no network address, so it always returns an empty string.
func (u *testUpstream) Address() (addr string) {
	return ""
}

// Close implements the upstream.Upstream interface for *testUpstream.  The test
// upstream holds no resources, so it always returns nil.
func (u *testUpstream) Close() (err error) {
	return nil
}

// newTestUpstreamConfigWithBoot parses addrs into a new UpstreamConfig whose
// upstreams resolve their hostnames using a caching bootstrap resolver backed by
// 8.8.8.8:53.  timeout is applied to both the bootstrap and the upstreams.  It
// fails the test on error.
func newTestUpstreamConfigWithBoot(
	t require.TestingT,
	timeout time.Duration,
	addrs ...string,
) (u *UpstreamConfig) {
	bootOpts := &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: timeout,
	}
	boot, err := upstream.NewUpstreamResolver("8.8.8.8:53", bootOpts)
	require.NoError(t, err)

	upsOpts := &upstream.Options{
		Logger:    slogutil.NewDiscardLogger(),
		Timeout:   timeout,
		Bootstrap: upstream.NewCachingResolver(boot),
	}
	u, err = ParseUpstreamsConfig(addrs, upsOpts)
	require.NoError(t, err)

	return u
}

// newTestUpstreamConfig parses addrs into a new UpstreamConfig using the given
// timeout, a discarding logger, and no explicit bootstrap resolver.  Any number
// of addresses, including domain-specific ones like "[/example/]1.2.3.4", may
// be passed.  It fails the test on error.
func newTestUpstreamConfig(
	t testing.TB,
	timeout time.Duration,
	addrs ...string,
) (u *UpstreamConfig) {
	t.Helper()

	upsConf, err := ParseUpstreamsConfig(addrs, &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: timeout,
	})
	require.NoError(t, err)

	return upsConf
}

// mustStartDefaultProxy creates and starts a proxy that listens for UDP and TCP
// on a random localhost port and forwards to [testDefaultUpstreamAddr].  The
// proxy is shut down on test cleanup.  It fails the test on error.
func mustStartDefaultProxy(t *testing.T) (p *Proxy) {
	t.Helper()

	conf := &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	}
	p = mustNew(t, conf)

	ctx := context.Background()
	require.NoError(t, p.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

	return p
}

// TestProxyRace sends multiple parallel DNS requests to the fully configured
// dnsproxy to check for race conditions.
func TestProxyRace(t *testing.T) {
	upsConf := newTestUpstreamConfig(
		t,
		defaultTimeout,
		// Use the same upstream twice so that we could rotate them.
		testDefaultUpstreamAddr,
		testDefaultUpstreamAddr,
	)
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         upsConf,
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-UDP client connection and make sure it's closed before
	// the proxy is shut down.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	g := &sync.WaitGroup{}
	g.Add(testMessagesCount)

	pt := testutil.PanicT{}
	for range testMessagesCount {
		go func() {
			defer g.Done()

			req := newTestMessage()
			writeErr := conn.WriteMsg(req)
			require.NoError(pt, writeErr)

			res, readErr := conn.ReadMsg()
			require.NoError(pt, readErr)

			// We do not check if msg IDs match because the order of responses may
			// be different.

			require.NotNil(pt, res)
			require.Len(pt, res.Answer, 1)
			require.IsType(pt, &dns.A{}, res.Answer[0])

			a := res.Answer[0].(*dns.A)
			require.Equal(pt, net.IPv4(8, 8, 8, 8), a.A.To16())
		}()
	}

	g.Wait()
}

// newTxts returns new test TXT RR strings containing txtDataLen random
// lowercase ASCII letters in total.  The data is split into chunks of at most
// 255 bytes, since that's the limit for a single character-string in a TXT RR.
func newTxts(t *testing.T, txtDataLen int) (txts []string) {
	t.Helper()

	const txtDataChunkLen = 255

	// Round up so that a trailing partial chunk gets its own string.
	txtDataChunkNum := (txtDataLen + txtDataChunkLen - 1) / txtDataChunkLen

	txts = make([]string, txtDataChunkNum)
	randData := make([]byte, txtDataLen)
	n, err := rand.Read(randData)
	require.NoError(t, err)
	require.Equal(t, txtDataLen, n)

	for i, c := range randData {
		randData[i] = c%26 + 'a'
	}

	for i := range txtDataChunkNum {
		start := txtDataChunkLen * i
		end := min(start+txtDataChunkLen, txtDataLen)
		txts[i] = string(randData[start:end])
	}

	return txts
}

// newDNSContext returns a new DNS context for a [ProtoUDP] request built by
// newReq from domain, qtype, and qclass.  When edns is true, an OPT RR with the
// given UDP buffer size and the DO bit set is added to the request.
func newDNSContext(
	domain string,
	qtype uint16,
	qclass uint16,
	edns bool,
	udpsize uint16,
) (dctx *DNSContext) {
	dctx = &DNSContext{
		Req:   newReq(domain, qtype, qclass),
		Proto: ProtoUDP,
	}

	if edns {
		dctx.Req.SetEdns0(udpsize, true)
	}

	return dctx
}

// newReq returns a new compressed, non-recursive request message with a random
// ID and a single question for the fully-qualified form of domain.
func newReq(domain string, qtype, qclass uint16) (req *dns.Msg) {
	q := dns.Question{
		Name:   dns.Fqdn(domain),
		Qtype:  qtype,
		Qclass: qclass,
	}

	req = &dns.Msg{
		Compress: true,
		Question: []dns.Question{q},
	}
	req.Id = dns.Id()

	return req
}

// TestProxy_Resolve_dnssecCache checks that responses are cached together with
// their DNSSEC RRs while clients only receive those RRs when they ask for them
// with EDNS.
func TestProxy_Resolve_dnssecCache(t *testing.T) {
	const (
		host = "example.com"

		// Larger than UDP buffer size to invoke truncation.
		txtDataLen = 1024
	)

	txt := &dns.TXT{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeTXT,
			Class:  dns.ClassINET,
		},
		Txt: newTxts(t, txtDataLen),
	}

	a := &dns.A{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
		},
		A: net.IP{1, 2, 3, 4},
	}

	ds := &dns.DS{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeDS,
			Class:  dns.ClassINET,
		},
		Digest: "736f6d652064656c65676174696f6e207369676e6572",
	}

	rrsig := &dns.RRSIG{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeRRSIG,
			Class:  dns.ClassINET,
			Ttl:    defaultTestTTL,
		},
		TypeCovered: dns.TypeA,
		Algorithm:   8,
		Labels:      2,
		SignerName:  dns.Fqdn(host),
		Signature:   "c29tZSBycnNpZyByZWxhdGVkIHN0dWZm",
	}

	u := &fakeUpstream{
		onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) {
			resp = (&dns.Msg{}).SetReply(m)

			q := m.Question[0]
			switch q.Qtype {
			case dns.TypeA:
				resp.Answer = append(resp.Answer, a)
			case dns.TypeTXT:
				resp.Answer = append(resp.Answer, txt)
			case dns.TypeDS:
				resp.Answer = append(resp.Answer, ds)
			default:
				// Go on.  The RRSIG resource record is added afterward.  This
				// upstream.Upstream implementation doesn't handle explicit
				// requests for it.
			}

			if len(resp.Answer) > 0 {
				resp.Answer[0].Header().Ttl = defaultTestTTL
			}

			if o := m.IsEdns0(); o != nil {
				resp.Answer = append(resp.Answer, rrsig)
				resp.SetEdns0(defaultUDPBufSize, o.Do())
			}

			return resp, nil
		},
		onAddress: func() (addr string) { return "" },
		onClose:   func() (err error) { return nil },
	}

	p := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         &UpstreamConfig{Upstreams: []upstream.Upstream{u}},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
		CacheSizeBytes:         defaultCacheSize,
	})

	testCases := []struct {
		wantAns dns.RR
		name    string
		wantLen int
		edns    bool
	}{{
		wantAns: a,
		name:    "a_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: a,
		name:    "a_edns",
		wantLen: 2,
		edns:    true,
	}, {
		wantAns: txt,
		name:    "txt_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: txt,
		name:    "txt_edns",
		// Truncated.
		wantLen: 0,
		edns:    true,
	}, {
		wantAns: ds,
		name:    "ds_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: ds,
		name:    "ds_edns",
		wantLen: 2,
		edns:    true,
	}}

	for _, tc := range testCases {
		ansHdr := tc.wantAns.Header()
		dctx := newDNSContext(ansHdr.Name, ansHdr.Rrtype, ansHdr.Class, tc.edns, txtDataLen/2)

		t.Run(tc.name, func(t *testing.T) {
			t.Cleanup(p.cache.items.Clear)

			err := p.Resolve(dctx)
			require.NoError(t, err)

			res := dctx.Res
			require.NotNil(t, res)

			require.Len(t, res.Answer, tc.wantLen, res.Answer)
			switch tc.wantLen {
			case 0:
				assert.True(t, res.Truncated)
			case 1:
				res.Answer[0].Header().Ttl = defaultTestTTL
				assert.Equal(t, tc.wantAns, res.Answer[0])
			case 2:
				res.Answer[0].Header().Ttl = defaultTestTTL
				assert.Equal(t, tc.wantAns, res.Answer[0])
				assert.Equal(t, rrsig, res.Answer[1])
			default:
				t.Fatalf("wanted length has unexpected value %d", tc.wantLen)
			}

			// The cache always keeps the DNSSEC RRs, regardless of whether the
			// client asked for them.
			cached, expired, key := p.cache.get(dctx.Req)
			require.NotNil(t, cached)
			require.Len(t, cached.m.Answer, 2)

			assert.False(t, expired)
			assert.Equal(t, key, msgToKey(dctx.Req))

			// Just make it match.
			cached.m.Answer[0].Header().Ttl = defaultTestTTL
			assert.Equal(t, tc.wantAns.String(), cached.m.Answer[0].String())
			assert.Equal(t, rrsig.String(), cached.m.Answer[1].String())
		})
	}
}

// TestExchangeWithReservedDomains checks that domain-specific upstreams are
// used for their domains and that the "#" upstream makes a subdomain fall back
// to the default upstreams.
func TestExchangeWithReservedDomains(t *testing.T) {
	t.Parallel()

	dnsProxy := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: newTestUpstreamConfigWithBoot(
			t,
			testTimeout,
			"[/adguard.com/]1.2.3.4",
			"[/google.ru/]2.3.4.5",
			"[/maps.google.ru/]#",
			"1.1.1.1",
		),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-TCP client connection.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Create google-a test message.
	req := newTestMessage()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	// Make sure that dnsproxy is working.
	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	// Create adguard.com test message.
	req = newHostTestMessage("adguard.com")
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	// Test message should not be resolved.  The proxy still responds, so a
	// read error is a failure, and it also guards the dereference below.
	res, err = conn.ReadMsg()
	require.NoError(t, err)
	require.Empty(t, res.Answer)

	// Create www.google.ru test message.
	req = newHostTestMessage("www.google.ru")
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	// Test message should not be resolved.
	res, err = conn.ReadMsg()
	require.NoError(t, err)
	require.Empty(t, res.Answer)

	// Create maps.google.ru test message.
	req = newHostTestMessage("maps.google.ru")
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	// Test message should be resolved.
	res, err = conn.ReadMsg()
	require.NoError(t, err)
	require.NotNil(t, res.Answer)
}

// TestOneByOneUpstreamsExchange tries to resolve DNS request with one valid and
// two invalid upstreams and checks that the answer arrives in time.
func TestOneByOneUpstreamsExchange(t *testing.T) {
	t.Parallel()

	dnsProxy := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: newTestUpstreamConfigWithBoot(
			t,
			testTimeout,
			"https://fake-dns.com/fake-dns-query",
			"tls://fake-dns.com",
			"1.1.1.1",
		),
		TrustedProxies:         defaultTrustedProxies,
		Fallbacks:              newTestUpstreamConfig(t, testTimeout, "1.2.3.4:567"),
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-TCP client connection.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Make sure that the response is okay and resolved by the valid upstream.
	req := newTestMessage()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	start := time.Now()
	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	elapsed := time.Since(start)
	assert.Greater(t, 3*testTimeout, elapsed)
}

// newLocalUpstreamListener starts a plain DNS-over-TCP server with handler h on
// 127.0.0.1 and the specified port, 0 meaning any free port.  It waits for the
// server to start, shuts it down on test cleanup, and returns the actual
// listening address.
func newLocalUpstreamListener(t *testing.T, port uint16, h dns.Handler) (real netip.AddrPort) {
	t.Helper()

	started := make(chan struct{})
	srv := &dns.Server{
		Addr:              netip.AddrPortFrom(netutil.IPv4Localhost(), port).String(),
		Net:               "tcp",
		Handler:           h,
		NotifyStartedFunc: func() { close(started) },
	}

	go func() {
		serveErr := srv.ListenAndServe()
		require.NoError(testutil.PanicT{}, serveErr)
	}()

	<-started
	testutil.CleanupAndRequireSuccess(t, srv.Shutdown)

	tcpAddr := testutil.RequireTypeAssert[*net.TCPAddr](t, srv.Listener.Addr())

	return tcpAddr.AddrPort()
}

// TestFallback checks the order in which the main and fallback upstreams,
// including domain-specific ones, are queried when the main ones fail.
func TestFallback(t *testing.T) {
	t.Parallel()

	responseCh := make(chan uint16)
	failCh := make(chan uint16)

	successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		testutil.RequireSend(testutil.PanicT{}, responseCh, r.Id, testTimeout)

		require.NoError(testutil.PanicT{}, w.WriteMsg((&dns.Msg{}).SetReply(r)))
	})
	failHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		testutil.RequireSend(testutil.PanicT{}, failCh, r.Id, testTimeout)

		require.NoError(testutil.PanicT{}, w.WriteMsg(&dns.Msg{}))
	})

	successAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, successHandler).String(),
	}).String()
	alsoSuccessAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, successHandler).String(),
	}).String()
	failAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, failHandler).String(),
	}).String()

	dnsProxy := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: newTestUpstreamConfig(
			t,
			testTimeout,
			failAddr,
			"[/specific.example/]"+alsoSuccessAddr,
			// almost.failing.example will fall here first.
			"[/failing.example/]"+failAddr,
		),
		TrustedProxies: defaultTrustedProxies,
		Fallbacks: newTestUpstreamConfig(
			t,
			testTimeout,
			failAddr,
			successAddr,
			"[/failing.example/]"+failAddr,
			"[/almost.failing.example/]"+alsoSuccessAddr,
		),
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	conn, err := dns.Dial("tcp", dnsProxy.Addr(ProtoTCP).String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	testCases := []struct {
		name        string
		wantSignals []chan uint16
	}{{
		name: "general.example",
		wantSignals: []chan uint16{
			failCh,
			// Both non-specific fallbacks tried.
			failCh,
			responseCh,
		},
	}, {
		name: "specific.example",
		wantSignals: []chan uint16{
			responseCh,
		},
	}, {
		name: "failing.example",
		wantSignals: []chan uint16{
			failCh,
			failCh,
		},
	}, {
		name: "almost.failing.example",
		wantSignals: []chan uint16{
			failCh,
			responseCh,
		},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := newHostTestMessage(tc.name)

			// Use local errors instead of the captured outer err, so that the
			// subtests don't share any mutable state.
			writeErr := conn.WriteMsg(req)
			require.NoError(t, writeErr)

			for _, ch := range tc.wantSignals {
				// Use t, not PanicT, since this runs on the test goroutine and
				// should fail the subtest rather than panic.
				reqID, ok := testutil.RequireReceive(t, ch, testTimeout)
				require.True(t, ok)

				assert.Equal(t, req.Id, reqID)
			}

			_, readErr := conn.ReadMsg()
			require.NoError(t, readErr)
		})
	}
}

// TestFallbackFromInvalidBootstrap checks that a request is answered via the
// fallback upstreams when the main upstream is configured with an invalid
// bootstrap resolver, and that the whole round trip doesn't take too long.
func TestFallbackFromInvalidBootstrap(t *testing.T) {
	t.Parallel()

	// The bootstrap resolver is deliberately invalid, so the main upstream's
	// hostname is not expected to be resolved.
	invalidRslv, err := upstream.NewUpstreamResolver("8.8.8.8:555", &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: testTimeout,
	})
	require.NoError(t, err)

	// Prepare the proxy server.
	upsConf, err := ParseUpstreamsConfig([]string{"tls://dns.adguard.com"}, &upstream.Options{
		Logger:    slogutil.NewDiscardLogger(),
		Bootstrap: invalidRslv,
		Timeout:   testTimeout,
	})
	require.NoError(t, err)

	dnsProxy := mustNew(t, &Config{
		Logger:         slogutil.NewDiscardLogger(),
		UDPListenAddr:  []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:  []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: upsConf,
		TrustedProxies: defaultTrustedProxies,
		Fallbacks: newTestUpstreamConfig(
			t,
			testTimeout,
			"1.0.0.1",
			"8.8.8.8",
		),
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	// Start listening.
	ctx := context.Background()
	err = dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-UDP client connection and make sure it's closed when
	// the test finishes.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Make sure that the response is okay and resolved by the fallback.  Start
	// the timer before writing the request so that the whole round trip is
	// measured, not only the reading of the response.
	req := newTestMessage()
	start := time.Now()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	elapsed := time.Since(start)
	assert.Greater(t, 3*testTimeout, elapsed)
}

// TestRefuseAny checks that a proxy configured with RefuseAny answers queries
// of type ANY with the NOTIMP response code.
func TestRefuseAny(t *testing.T) {
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		RefuseAny:              true,
	})

	ctx := context.Background()
	require.NoError(t, dnsProxy.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	cli := &dns.Client{Net: string(ProtoUDP), Timeout: testTimeout}

	// Build an ANY query with recursion desired.
	anyReq := &dns.Msg{}
	anyReq.Id = dns.Id()
	anyReq.RecursionDesired = true
	anyReq.SetQuestion("google.com.", dns.TypeANY)

	resp, _, err := cli.Exchange(anyReq, dnsProxy.Addr(ProtoUDP).String())
	require.NoError(t, err)

	assert.Equal(t, dns.RcodeNotImplemented, resp.Rcode)
}

// TestInvalidDNSRequest checks that a request without a question section is
// answered with SERVFAIL.
func TestInvalidDNSRequest(t *testing.T) {
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		RefuseAny:              true,
	})

	ctx := context.Background()
	require.NoError(t, dnsProxy.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	cli := &dns.Client{Net: string(ProtoUDP), Timeout: testTimeout}

	// The request deliberately has no question section.
	emptyReq := &dns.Msg{}
	emptyReq.Id = dns.Id()
	emptyReq.RecursionDesired = true

	resp, _, err := cli.Exchange(emptyReq, dnsProxy.Addr(ProtoUDP).String())
	require.NoError(t, err)

	assert.Equal(t, dns.RcodeServerFailure, resp.Rcode)
}

// TestResponseInRequest checks that the server drops incoming messages with the
// Response flag set, so that the client only gets a timeout.
func TestResponseInRequest(t *testing.T) {
	p := mustStartDefaultProxy(t)

	msg := newTestMessage()
	msg.Response = true

	cli := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}
	resp, _, err := cli.Exchange(msg, p.Addr(ProtoUDP).String())

	var opErr *net.OpError
	require.ErrorAs(t, err, &opErr)

	assert.True(t, opErr.Timeout())
	assert.Nil(t, resp)
}

// TestNoQuestion checks that the server responds with SERVFAIL to requests
// without a question section.
func TestNoQuestion(t *testing.T) {
	p := mustStartDefaultProxy(t)

	msg := newTestMessage()
	msg.Question = nil

	cli := &dns.Client{Net: string(ProtoUDP), Timeout: testTimeout}
	resp, _, err := cli.Exchange(msg, p.Addr(ProtoUDP).String())
	require.NoError(t, err)

	assert.Equal(t, dns.RcodeServerFailure, resp.Rcode)
}

// fakeUpstream is a mock upstream implementation to simplify testing.  It
// allows assigning custom Exchange, Address, and Close methods.  Each method
// calls the corresponding function field, which must be set if the method is
// called.
//
// TODO(e.burkov):  Use dnsproxytest.FakeUpstream instead.
type fakeUpstream struct {
	// onExchange is called by [fakeUpstream.Exchange].
	onExchange func(m *dns.Msg) (resp *dns.Msg, err error)

	// onAddress is called by [fakeUpstream.Address].
	onAddress func() (addr string)

	// onClose is called by [fakeUpstream.Close].
	onClose func() (err error)
}

// type check
var _ upstream.Upstream = (*fakeUpstream)(nil)

// Exchange implements the [upstream.Upstream] interface for *fakeUpstream.
func (u *fakeUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return u.onExchange(m) }

// Address implements the [upstream.Upstream] interface for *fakeUpstream.
func (u *fakeUpstream) Address() (addr string) { return u.onAddress() }

// Close implements the [upstream.Upstream] interface for *fakeUpstream.
func (u *fakeUpstream) Close() (err error) { return u.onClose() }

// TestProxy_ReplyFromUpstream_badResponse checks that resolving via a custom
// upstream that returns a response with an emptied question section doesn't
// panic and yields a response carrying the request's question.
func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) {
	dnsProxy := mustStartDefaultProxy(t)

	badUps := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			resp = &dns.Msg{}
			resp.SetReply(req)

			ans := &dns.A{
				Hdr: dns.RR_Header{
					Name:   req.Question[0].Name,
					Class:  dns.ClassINET,
					Rrtype: dns.TypeA,
				},
				A: net.IP{1, 2, 3, 4},
			}
			resp.Answer = append(resp.Answer, ans)

			// Invalidate the response by dropping its question.
			resp.Question = []dns.Question{}

			return resp, nil
		},
		onAddress: func() (addr string) { return "stub" },
		onClose:   func() error { panic("not implemented") },
	}

	upsConf := &UpstreamConfig{Upstreams: []upstream.Upstream{badUps}}
	dctx := &DNSContext{
		CustomUpstreamConfig: NewCustomUpstreamConfig(upsConf, false, 0, false),
		Req:                  newHostTestMessage("host"),
		Addr:                 netip.MustParseAddrPort("1.2.3.0:1234"),
	}

	var err error
	require.NotPanics(t, func() { err = dnsProxy.Resolve(dctx) })
	require.NoError(t, err)

	assert.Equal(t, dctx.Req.Question[0], dctx.Res.Question[0])
}

// TestExchangeCustomUpstreamConfig checks that a request carrying a custom
// upstream configuration is resolved by the upstream from that configuration.
func TestExchangeCustomUpstreamConfig(t *testing.T) {
	prx := mustStartDefaultProxy(t)

	wantIP := net.IP{4, 3, 2, 1}
	rr := &dns.A{
		Hdr: dns.RR_Header{
			Rrtype: dns.TypeA,
			Name:   "host.",
			Ttl:    60,
		},
		A: wantIP,
	}
	ups := &testUpstream{ans: []dns.RR{rr}}

	conf := NewCustomUpstreamConfig(&UpstreamConfig{Upstreams: []upstream.Upstream{ups}}, false, 0, false)
	dctx := &DNSContext{
		CustomUpstreamConfig: conf,
		Req:                  newHostTestMessage("host"),
		Addr:                 netip.MustParseAddrPort("1.2.3.0:1234"),
	}

	require.NoError(t, prx.Resolve(dctx))

	assert.Equal(t, wantIP, firstIP(dctx.Res))
}

// TestExchangeCustomUpstreamConfigCache checks that responses resolved via a
// custom upstream configuration with its own cache are served from that cache
// until the cache is cleared.
func TestExchangeCustomUpstreamConfigCache(t *testing.T) {
	prx := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		CacheEnabled:           true,
	})

	ctx := context.Background()
	require.NoError(t, prx.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

	ansIP := net.IP{4, 3, 2, 1}

	// exchanges counts the requests that reached the upstream.
	exchanges := 0
	ups := &fakeUpstream{
		onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			resp = (&dns.Msg{}).SetReply(req)
			resp.Answer = append(resp.Answer, &dns.A{
				Hdr: dns.RR_Header{
					Name:   req.Question[0].Name,
					Class:  dns.ClassINET,
					Rrtype: dns.TypeA,
					Ttl:    defaultTestTTL,
				},
				A: ansIP,
			})

			exchanges++

			return resp, nil
		},
		onAddress: func() (addr string) { return "stub" },
		onClose:   func() error { panic("not implemented") },
	}

	customConf := NewCustomUpstreamConfig(
		&UpstreamConfig{Upstreams: []upstream.Upstream{ups}},
		true,
		defaultCacheSize,
		prx.EnableEDNSClientSubnet,
	)
	dctx := DNSContext{
		CustomUpstreamConfig: customConf,
		Req:                  newHostTestMessage("host"),
		Addr:                 netip.MustParseAddrPort("1.2.3.0:1234"),
	}

	// The first request reaches the upstream.
	require.NoError(t, prx.Resolve(&dctx))
	require.Equal(t, 1, exchanges)
	assert.Equal(t, ansIP, firstIP(dctx.Res))

	// The second one is served from the cache.
	require.NoError(t, prx.Resolve(&dctx))
	assert.Equal(t, 1, exchanges)
	assert.Equal(t, ansIP, firstIP(dctx.Res))

	// Once the cache is cleared, the upstream is queried again.
	customConf.ClearCache()

	require.NoError(t, prx.Resolve(&dctx))
	assert.Equal(t, 2, exchanges)
	assert.Equal(t, ansIP, firstIP(dctx.Res))
}

// TestECS checks that setECS adds an EDNS Client Subnet option to a message
// and that ecsFromMsg reads back the same subnet along with the scope.
func TestECS(t *testing.T) {
	t.Run("ipv4", func(t *testing.T) {
		clientIP := net.IP{1, 2, 3, 4}
		msg := &dns.Msg{}

		set := setECS(msg, clientIP, 16)
		setOnes, _ := set.Mask.Size()
		assert.Equal(t, 24, setOnes)

		got, gotScope := ecsFromMsg(msg)
		assert.Equal(t, clientIP.Mask(got.Mask), got.IP)

		gotOnes, _ := got.Mask.Size()
		assert.Equal(t, 24, gotOnes)
		assert.Equal(t, 16, gotScope)
	})

	t.Run("ipv6", func(t *testing.T) {
		clientIP := net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
		msg := &dns.Msg{}

		set := setECS(msg, clientIP, 48)
		setOnes, _ := set.Mask.Size()
		assert.Equal(t, 56, setOnes)

		got, gotScope := ecsFromMsg(msg)
		assert.Equal(t, clientIP.Mask(got.Mask), got.IP)

		gotOnes, _ := got.Mask.Size()
		assert.Equal(t, 56, gotOnes)
		assert.Equal(t, 48, gotScope)
	})
}

// TestECSProxy resolves the same host for clients from different subnets and
// checks that, with EDNS Client Subnet and the cache enabled, responses are
// cached per subnet and then served without querying the upstream.
//
// The subtests rely on the cache state left by the previous ones and on the
// modifications of u, so they must run sequentially and in this order.
func TestECSProxy(t *testing.T) {
	var (
		ip1230 = net.IP{1, 2, 3, 0}
		ip2230 = net.IP{2, 2, 3, 0}
		ip4321 = net.IP{4, 3, 2, 1}
		ip4322 = net.IP{4, 3, 2, 2}
		ip4323 = net.IP{4, 3, 2, 3}
	)

	u := &testUpstream{
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4321,
		}},
		ecsIP: ip1230,
	}

	prx := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{u},
		},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		EnableEDNSClientSubnet: true,
		CacheEnabled:           true,
	})

	ctx := context.Background()
	err := prx.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

	t.Run("cache_subnet", func(t *testing.T) {
		d := DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("1.2.3.0:1234"),
		}

		require.NoError(t, prx.Resolve(&d))

		assert.Equal(t, ip4321, firstIP(d.Res))
		assert.Equal(t, ip1230, u.ecsReqIP)
	})

	t.Run("serve_subnet_cache", func(t *testing.T) {
		d := &DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("1.2.3.1:1234"),
		}
		// Make sure the upstream can't answer, so the response comes from the
		// subnet cache.
		u.ans, u.ecsIP, u.ecsReqIP = nil, nil, nil

		require.NoError(t, prx.Resolve(d))

		assert.Equal(t, ip4321, firstIP(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})

	t.Run("another_subnet", func(t *testing.T) {
		d := DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("2.2.3.0:1234"),
		}
		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4322,
		}}
		u.ecsIP = ip2230

		require.NoError(t, prx.Resolve(&d))

		assert.Equal(t, ip4322, firstIP(d.Res))
		assert.Equal(t, ip2230, u.ecsReqIP)
	})

	t.Run("cache_general", func(t *testing.T) {
		d := DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("127.0.0.1:1234"),
		}
		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4323,
		}}
		u.ecsIP, u.ecsReqIP = nil, nil

		require.NoError(t, prx.Resolve(&d))

		assert.Equal(t, ip4323, firstIP(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})

	t.Run("serve_general_cache", func(t *testing.T) {
		d := DNSContext{
			Req:  newHostTestMessage("host"),
			Addr: netip.MustParseAddrPort("127.0.0.2:1234"),
		}
		// Make sure the upstream can't answer, so the response comes from the
		// general cache.
		u.ans, u.ecsIP, u.ecsReqIP = nil, nil, nil

		require.NoError(t, prx.Resolve(&d))

		assert.Equal(t, ip4323, firstIP(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})
}

// TestECSProxyCacheMinMaxTTL checks that responses stored in the subnet cache
// have their TTLs clamped to the configured CacheMinTTL and CacheMaxTTL.
func TestECSProxyCacheMinMaxTTL(t *testing.T) {
	clientIP := net.IP{1, 2, 3, 0}
	u := &testUpstream{
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host.",
				Ttl:    10,
			},
			A: net.IP{4, 3, 2, 1},
		}},
		ecsIP: clientIP,
	}

	prx := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         &UpstreamConfig{Upstreams: []upstream.Upstream{u}},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		EnableEDNSClientSubnet: true,
		CacheEnabled:           true,
		CacheMinTTL:            20,
		CacheMaxTTL:            40,
	})

	ctx := context.Background()
	err := prx.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) })

	// The first request gets an answer with a TTL below CacheMinTTL.
	d := DNSContext{
		Req:  newHostTestMessage("host"),
		Addr: netip.MustParseAddrPort("1.2.3.0:1234"),
	}
	err = prx.Resolve(&d)
	require.NoError(t, err)

	// Get it from the cache and check that the TTL was raised to the minimum.
	ci, expired, key := prx.cache.getWithSubnet(d.Req, &net.IPNet{
		IP:   clientIP,
		Mask: net.CIDRMask(24, netutil.IPv4BitLen),
	})
	require.NotNil(t, ci)
	require.NotEmpty(t, ci.m.Answer)

	assert.False(t, expired)
	assert.Equal(t, msgToKeyWithSubnet(d.Req, clientIP, 24), key)
	assert.Equal(t, prx.CacheMinTTL, ci.m.Answer[0].Header().Ttl)

	// The second request, from another subnet, gets an answer with a TTL above
	// CacheMaxTTL.
	clientIP = net.IP{1, 2, 4, 0}
	d.Req = newHostTestMessage("host")
	d.Addr = netip.MustParseAddrPort("1.2.4.0:1234")
	u.ans = []dns.RR{&dns.A{
		Hdr: dns.RR_Header{
			Rrtype: dns.TypeA,
			Name:   "host.",
			Ttl:    60,
		},
		A: net.IP{4, 3, 2, 1},
	}}
	u.ecsIP = clientIP
	err = prx.Resolve(&d)
	require.NoError(t, err)

	// Get it from the cache and check that the TTL was lowered to the maximum.
	ci, expired, key = prx.cache.getWithSubnet(d.Req, &net.IPNet{
		IP:   clientIP,
		Mask: net.CIDRMask(24, netutil.IPv4BitLen),
	})
	require.NotNil(t, ci)
	require.NotEmpty(t, ci.m.Answer)

	assert.False(t, expired)
	assert.Equal(t, msgToKeyWithSubnet(d.Req, clientIP, 24), key)
	assert.Equal(t, prx.CacheMaxTTL, ci.m.Answer[0].Header().Ttl)
}

// TestProxy_Resolve_withOptimisticResolver checks that, with the optimistic
// cache enabled, an expired cached response is served with optimisticTTL while
// the background update is still in progress, and that the response from the
// upstream is cached with its own TTL once the update finishes.
func TestProxy_Resolve_withOptimisticResolver(t *testing.T) {
	const (
		host             = "some.domain.name."
		nonOptimisticTTL = 3600
	)

	// buildCtx returns a new context with an A request for host.
	buildCtx := func() (dctx *DNSContext) {
		req := &dns.Msg{
			MsgHdr: dns.MsgHdr{Id: dns.Id()},
			Question: []dns.Question{{
				Name:   host,
				Qtype:  dns.TypeA,
				Qclass: dns.ClassINET,
			}},
		}

		return &DNSContext{Req: req}
	}
	// buildResp returns a reply to req with a single A record having ttl.
	buildResp := func(req *dns.Msg, ttl uint32) (resp *dns.Msg) {
		resp = (&dns.Msg{}).SetReply(req)
		resp.Answer = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   host,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    ttl,
			},
			A: net.IP{1, 2, 3, 4},
		}}

		return resp
	}

	p := &Proxy{
		Config: Config{
			CacheEnabled:    true,
			CacheOptimistic: true,
		},
		logger: slogutil.NewDiscardLogger(),
	}

	p.initCache()
	// out and in are unbuffered, so each send below blocks until the other side
	// receives, which lets the test pause the background update mid-way.
	out, in := make(chan unit), make(chan unit)
	p.shortFlighter.cr = &testCachingResolver{
		onReplyFromUpstream: func(dctx *DNSContext) (ok bool, err error) {
			dctx.Res = buildResp(dctx.Req, nonOptimisticTTL)

			return true, nil
		},
		onCacheResp: func(dctx *DNSContext) {
			// Report adding to cache is in process.
			out <- unit{}
			// Wait for tests to finish.
			<-in

			p.cacheResp(dctx)

			// Report adding to cache is finished.
			out <- unit{}
		},
	}

	// Two different contexts are made to emulate two different requests
	// with the same question section.
	firstCtx, secondCtx := buildCtx(), buildCtx()

	// Add expired response into cache.
	req := firstCtx.Req
	key := msgToKey(req)
	data := (&cacheItem{
		m: buildResp(req, 0),
		u: testUpsAddr,
	}).pack()
	items := glcache.New(glcache.Config{
		EnableLRU: true,
	})
	items.Set(key, data)
	p.cache.items = items

	err := p.Resolve(firstCtx)
	require.NoError(t, err)
	require.Len(t, firstCtx.Res.Answer, 1)

	assert.EqualValues(t, optimisticTTL, firstCtx.Res.Answer[0].Header().Ttl)

	// Wait for optimisticResolver to reach the tested function.
	<-out

	// The update is paused before caching, so the second request must still be
	// served the stale response.
	err = p.Resolve(secondCtx)
	require.NoError(t, err)
	require.Len(t, secondCtx.Res.Answer, 1)

	assert.EqualValues(t, optimisticTTL, secondCtx.Res.Answer[0].Header().Ttl)

	// Continue and wait for it to finish.
	in <- unit{}
	<-out

	// Should be served from cache.
	data = p.cache.items.Get(msgToKey(firstCtx.Req))
	unpacked, expired := p.cache.unpackItem(data, firstCtx.Req)
	require.False(t, expired)
	require.NotNil(t, unpacked)
	require.Len(t, unpacked.m.Answer, 1)

	assert.EqualValues(t, nonOptimisticTTL, unpacked.m.Answer[0].Header().Ttl)
}

// TestProxy_HandleDNSRequest_private checks how PTR requests are routed when
// UsePrivateRDNS is enabled:
//
//   - requests for external addresses are answered by the general upstreams;
//   - requests for private addresses from private clients are answered by the
//     private upstreams;
//   - requests for private addresses from external clients get the NXDOMAIN
//     response built by the message constructor.
func TestProxy_HandleDNSRequest_private(t *testing.T) {
	t.Parallel()

	privateSet := netutil.SubnetSetFunc(netutil.IsLocallyServed)

	localIP := netip.MustParseAddrPort("192.168.0.1:1")
	require.True(t, privateSet.Contains(localIP.Addr()))

	externalIP := netip.MustParseAddrPort("4.3.2.1:1")
	require.False(t, privateSet.Contains(externalIP.Addr()))

	privateReq := (&dns.Msg{}).SetQuestion("2.0.168.192.in-addr.arpa", dns.TypePTR)
	privateResp := (&dns.Msg{}).SetReply(privateReq)
	privateResp.Compress = true

	externalReq := (&dns.Msg{}).SetQuestion("2.2.3.4.in-addr.arpa", dns.TypePTR)
	externalResp := (&dns.Msg{}).SetReply(externalReq)
	externalResp.Compress = true

	nxdomainResp := (&dns.Msg{}).SetReply(privateReq)
	nxdomainResp.Rcode = dns.RcodeNameError

	// The upstreams return copies so that the expected responses above are
	// never modified by the proxy.
	generalUps := &fakeUpstream{
		onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) {
			return externalResp.Copy(), nil
		},
		onAddress: func() (addr string) { return "general" },
		onClose:   func() (err error) { return nil },
	}
	privateUps := &fakeUpstream{
		onExchange: func(m *dns.Msg) (resp *dns.Msg, err error) {
			return privateResp.Copy(), nil
		},
		onAddress: func() (addr string) { return "private" },
		onClose:   func() (err error) { return nil },
	}

	// NOTE(review): unlike the upstreams, this returns the shared nxdomainResp
	// itself rather than a copy; confirm the proxy never modifies it.
	messages := dnsproxytest.NewTestMessageConstructor()
	messages.OnNewMsgNXDOMAIN = func(_ *dns.Msg) (resp *dns.Msg) {
		return nxdomainResp
	}

	p := mustNew(t, &Config{
		Logger:        slogutil.NewDiscardLogger(),
		UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{generalUps},
		},
		PrivateRDNSUpstreamConfig: &UpstreamConfig{
			Upstreams: []upstream.Upstream{privateUps},
		},
		PrivateSubnets:     privateSet,
		UsePrivateRDNS:     true,
		MessageConstructor: messages,
	})
	ctx := context.Background()
	require.NoError(t, p.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return p.Shutdown(ctx) })

	testCases := []struct {
		name    string
		want    *dns.Msg
		req     *dns.Msg
		cliAddr netip.AddrPort
	}{{
		name:    "local_requests_external",
		want:    externalResp,
		req:     externalReq,
		cliAddr: localIP,
	}, {
		name:    "external_requests_external",
		want:    externalResp,
		req:     externalReq,
		cliAddr: externalIP,
	}, {
		name:    "local_requests_private",
		want:    privateResp,
		req:     privateReq,
		cliAddr: localIP,
	}, {
		name:    "external_requests_private",
		want:    nxdomainResp,
		req:     privateReq,
		cliAddr: externalIP,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// NOTE(review): the subtests run in parallel and share externalReq
			// and privateReq; this assumes handleDNSRequest doesn't modify the
			// request.  Confirm.
			t.Parallel()

			dctx := p.newDNSContext(ProtoUDP, tc.req, tc.cliAddr)

			require.NoError(t, p.handleDNSRequest(dctx))
			assert.Equal(t, tc.want, dctx.Res)
		})
	}
}
0707010000006A000081A4000000000000000000000001679A649F00000ED3000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/proxycache.gopackage proxy

import (
	"net"
	"slices"
)

// cacheForContext returns the cache to use for d: the cache of d's custom
// upstream configuration when it has one, and the proxy's own cache otherwise.
func (p *Proxy) cacheForContext(d *DNSContext) (c *cache) {
	c = p.cache
	if custom := d.CustomUpstreamConfig; custom != nil && custom.cache != nil {
		c = custom.cache
	}

	return c
}

// replyFromCache tries to answer d from the cache.  It uses the subnet cache
// when EDNS Client Subnet is enabled and d.ReqECS is set, and the general cache
// otherwise.  The cache of d's custom upstream configuration, if any, is
// preferred over the proxy's own one.
//
// On a hit, it sets d.Res and the query statistics.  If the cache is
// optimistic and the item has expired, it also starts resolving a reduced copy
// of d in the background.  hit is true if a cached response was found.
func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) {
	c := p.cacheForContext(d)

	var (
		item    *cacheItem
		expired bool
		key     []byte
		source  string
	)

	// TODO(d.kolyshev): Use EnableEDNSClientSubnet from the chosen cache.
	useSubnet := p.Config.EnableEDNSClientSubnet && d.ReqECS != nil
	if useSubnet {
		item, expired, key = c.getWithSubnet(d.Req, d.ReqECS)
		source = "subnet cache"
	} else {
		item, expired, key = c.get(d.Req)
		source = "general cache"
	}

	if item == nil {
		return false
	}

	d.Res = item.m
	d.queryStatistics = cachedQueryStatistics(item.u)

	p.logger.Debug(
		"replying from cache",
		"source", source,
		"ecs_enabled", p.Config.EnableEDNSClientSubnet,
	)

	if c.optimistic && expired {
		// The background resolution gets its own reduced copy of the context
		// so that it doesn't race with further processing of d.
		clone := &DNSContext{
			// It is only read inside the optimistic resolver.
			CustomUpstreamConfig: d.CustomUpstreamConfig,
			ReqECS:               cloneIPNet(d.ReqECS),
			IsPrivateClient:      d.IsPrivateClient,
		}
		if d.Req != nil {
			clone.Req = d.Req.Copy()
			addDO(clone.Req)
		}

		go p.shortFlighter.resolveOnce(clone, key, p.logger)
	}

	return true
}

// cloneIPNet returns a deep copy of n, so that the IP and mask of the result
// don't share memory with n.  It returns nil if n is nil.
func cloneIPNet(n *net.IPNet) (clone *net.IPNet) {
	if n != nil {
		clone = &net.IPNet{
			IP:   slices.Clone(n.IP),
			Mask: slices.Clone(n.Mask),
		}
	}

	return clone
}

// cacheResp stores the response from d in general or subnet cache.  In case the
// cache is present in d, it's used first.
//
// When EDNS Client Subnet is disabled in p, d.Res always goes to the general
// cache.  Otherwise the choice depends on the ECS option of d.Res and on
// d.ReqECS:
//
//   - both set: d.Res is cached for the response's subnet, provided that it
//     matches d.ReqECS, and is not cached at all otherwise;
//   - only d.ReqECS set: d.Res is cached in the subnet cache with an empty
//     subnet, i.e. for all subnets;
//   - d.ReqECS not set: d.Res goes to the general cache.
func (p *Proxy) cacheResp(d *DNSContext) {
	dctxCache := p.cacheForContext(d)

	if !p.EnableEDNSClientSubnet {
		dctxCache.set(d.Res, d.Upstream, p.logger)

		return
	}

	switch ecs, scope := ecsFromMsg(d.Res); {
	case ecs != nil && d.ReqECS != nil:
		ones, bits := ecs.Mask.Size()
		reqOnes, _ := d.ReqECS.Mask.Size()

		// If FAMILY, SOURCE PREFIX-LENGTH, and SOURCE PREFIX-LENGTH bits of
		// ADDRESS in the response don't match the non-zero fields in the
		// corresponding query, the full response MUST be dropped.
		//
		// See RFC 7871 Section 7.3.
		//
		// TODO(a.meshkov):  The whole response MUST be dropped if ECS in it
		// doesn't correspond.
		if !ecs.IP.Mask(ecs.Mask).Equal(d.ReqECS.IP.Mask(d.ReqECS.Mask)) || ones != reqOnes {
			p.logger.Debug(
				"not caching response; subnet mismatch",
				"ecs", ecs,
				"req_ecs", d.ReqECS,
			)

			return
		}

		// If SCOPE PREFIX-LENGTH is not longer than SOURCE PREFIX-LENGTH, store
		// SCOPE PREFIX-LENGTH bits of ADDRESS, and then mark the response as
		// valid for all addresses that fall within that range.
		//
		// See RFC 7871 Section 7.3.1.
		//
		// The equal case needs no change, since ones == reqOnes is checked
		// above; for the same reason, a longer scope leaves ecs with the SOURCE
		// PREFIX-LENGTH mask.
		if scope < reqOnes {
			ecs.Mask = net.CIDRMask(scope, bits)
			ecs.IP = ecs.IP.Mask(ecs.Mask)
		}

		p.logger.Debug("caching response", "ecs", ecs)

		dctxCache.setWithSubnet(d.Res, d.Upstream, ecs, p.logger)
	case d.ReqECS != nil:
		// Cache the response for all subnets since the server doesn't support
		// EDNS Client Subnet option.
		dctxCache.setWithSubnet(d.Res, d.Upstream, &net.IPNet{IP: nil, Mask: nil}, p.logger)
	default:
		dctxCache.set(d.Res, d.Upstream, p.logger)
	}
}

// ClearCache removes all items from both the general and the subnet parts of
// p's DNS cache.  It does nothing if p has no cache.
func (p *Proxy) ClearCache() {
	if c := p.cache; c != nil {
		c.clearItems()
		c.clearItemsWithSubnet()
		p.logger.Debug("cache cleared")
	}
}
0707010000006B000081A4000000000000000000000001679A649F000005CA000000000000000000000000000000000000002300000000dnsproxy-0.75.0/proxy/ratelimit.gopackage proxy

import (
	"fmt"
	"net/netip"
	"slices"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	rate "github.com/beefsack/go-rate"
	gocache "github.com/patrickmn/go-cache"
)

// limiterForIP returns the rate limiter stored for the ip key in
// p.ratelimitBuckets.  If there is none yet, it creates one that allows
// p.Ratelimit requests per second and stores it for an hour.  The buckets
// cache itself is created on the first call.  The returned value is expected
// to be a *rate.RateLimiter.
func (p *Proxy) limiterForIP(ip string) (limiter any) {
	p.ratelimitLock.Lock()
	defer p.ratelimitLock.Unlock()

	if p.ratelimitBuckets == nil {
		p.ratelimitBuckets = gocache.New(time.Hour, time.Hour)
	}

	var found bool
	limiter, found = p.ratelimitBuckets.Get(ip)
	if !found {
		limiter = rate.New(p.Ratelimit, time.Second)
		p.ratelimitBuckets.Set(ip, limiter, time.Hour)
	}

	return limiter
}

// isRatelimited reports whether a request from addr exceeds the rate limit and
// should therefore be dropped.  Limits are tracked per subnet: addr is masked
// to p.RatelimitSubnetLenIPv4 or p.RatelimitSubnetLenIPv6 bits, and all
// addresses within that subnet share one limiter.  It returns false if rate
// limiting is disabled or if addr is in p.RatelimitWhitelist.
func (p *Proxy) isRatelimited(addr netip.Addr) (limited bool) {
	if p.Ratelimit <= 0 {
		// The ratelimit is disabled.
		return false
	}

	addr = addr.Unmap()

	// Already sorted by [Proxy.Init].
	_, whitelisted := slices.BinarySearchFunc(p.RatelimitWhitelist, addr, netip.Addr.Compare)
	if whitelisted {
		return false
	}

	subnetLen := p.RatelimitSubnetLenIPv6
	if addr.Is4() {
		subnetLen = p.RatelimitSubnetLenIPv4
	}

	// TODO(s.chzhen):  Improve caching.  Decrease allocations.
	subnet := netip.PrefixFrom(addr, subnetLen).Masked()
	value := p.limiterForIP(subnet.Addr().String())
	rl, ok := value.(*rate.RateLimiter)
	if !ok {
		p.logger.Error(
			"invalid value found in ratelimit cache",
			slogutil.KeyError,
			fmt.Errorf("bad type %T", value),
		)

		return false
	}

	allow, _ := rl.Try()

	return !allow
}
0707010000006C000081A4000000000000000000000001679A649F00000959000000000000000000000000000000000000002800000000dnsproxy-0.75.0/proxy/ratelimit_test.gopackage proxy

import (
	"context"
	"net"
	"net/netip"
	"testing"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/require"
)

// TestRatelimitingProxy checks that, with a rate limit of one request per
// second, the proxy answers the first request from a client and drops the
// second one.
func TestRatelimitingProxy(t *testing.T) {
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		Ratelimit:              1,
	})

	// Start listening.
	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Create a DNS-over-UDP client.
	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{
		Net:     string(ProtoUDP),
		Timeout: testTimeout,
	}

	// Send the first message, which must not be blocked.
	req := newTestMessage()
	r, _, err := client.Exchange(req, addr.String())
	require.NoError(t, err, "the first request")
	requireResponse(t, req, r)

	// Send the second message, which must be blocked.
	req = newTestMessage()
	_, _, err = client.Exchange(req, addr.String())
	require.Error(t, err, "the second request was not blocked")
}

// TestRatelimiting checks that, with a rate limit of one request per second,
// the second request from the same address is rate limited.
func TestRatelimiting(t *testing.T) {
	p := Proxy{}
	p.Ratelimit = 1

	addr := netip.MustParseAddr("127.0.0.1")

	require.False(t, p.isRatelimited(addr), "the first request must have been allowed")
	require.True(t, p.isRatelimited(addr), "the second request must have been ratelimited")
}

// TestWhitelist checks that, with a rate limit of one request per second,
// requests from a whitelisted address are never rate limited.
func TestWhitelist(t *testing.T) {
	p := Proxy{}
	p.Ratelimit = 1
	// isRatelimited uses binary search, so the whitelist must be sorted.
	p.RatelimitWhitelist = []netip.Addr{
		netip.MustParseAddr("127.0.0.1"),
		netip.MustParseAddr("127.0.0.2"),
		netip.MustParseAddr("127.0.0.125"),
	}

	addr := netip.MustParseAddr("127.0.0.1")

	require.False(t, p.isRatelimited(addr), "the first request must have been allowed")
	require.False(t, p.isRatelimited(addr), "the second request must have been allowed due to whitelist")
}
0707010000006D000081A4000000000000000000000001679A649F000009A9000000000000000000000000000000000000002B00000000dnsproxy-0.75.0/proxy/recursiondetector.gopackage proxy

import (
	"encoding/binary"
	"time"

	glcache "github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

// Sizes of the uint16 and uint64 types in bytes, named to improve readability.
//
// TODO(e.burkov): Remove when there is a more convenient way to define those.
// See https://github.com/golang/go/issues/29982.
const (
	uint16sz = 16 / 8
	uint64sz = 64 / 8
)

// TODO(e.burkov):  Consider making configurable.
const (
	// recursionTTL is how long a recursive request stays cached.
	recursionTTL = time.Second

	// cachedRecurrentReqNum is the maximum number of recurrent requests kept in
	// the cache at once.
	cachedRecurrentReqNum = 1_000
)

// recursionDetector detects recursion in DNS forwarding by remembering the
// requests the server has sent for a limited time.
type recursionDetector struct {
	recentRequests glcache.Cache // msgToSignature keys to big-endian UnixNano expiry
	ttl            time.Duration // how long an added request is considered recent
}

// check reports whether a message with the same signature as msg was added and
// its stored expiration time hasn't passed yet, i.e. whether the server has
// already sent it recently.  Messages without questions are never reported.
func (rd *recursionDetector) check(msg *dns.Msg) (ok bool) {
	if len(msg.Question) == 0 {
		return false
	}

	data := rd.recentRequests.Get(msgToSignature(msg))
	if data == nil {
		return false
	}

	expireNano := int64(binary.BigEndian.Uint64(data))

	return time.Now().Before(time.Unix(0, expireNano))
}

// add remembers msg's signature together with an expiration time of rd.ttl
// from now.  Messages without questions are ignored.
func (rd *recursionDetector) add(msg *dns.Msg) {
	now := time.Now()

	if len(msg.Question) == 0 {
		return
	}

	sig := msgToSignature(msg)

	expireAt := make([]byte, uint64sz)
	binary.BigEndian.PutUint64(expireAt, uint64(now.Add(rd.ttl).UnixNano()))

	rd.recentRequests.Set(sig, expireAt)
}

// clear forgets all the requests remembered by rd.
func (rd *recursionDetector) clear() { rd.recentRequests.Clear() }

// newRecursionDetector returns a new *recursionDetector that remembers requests
// for ttl and stores at most suspectsNum of them in a cache with LRU enabled.
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
	conf := glcache.Config{
		MaxCount:  suspectsNum,
		EnableLRU: true,
	}

	rd = &recursionDetector{
		recentRequests: glcache.New(conf),
		ttl:            ttl,
	}

	return rd
}

// msgToSignature converts msg into its signature represented in bytes.  The
// signature is the message ID and the type of the first question, both
// big-endian, followed by the question's name zero-padded to
// [netutil.MaxDomainNameLen] bytes.  Longer names are truncated.  msg must have
// at least one question.
func msgToSignature(msg *dns.Msg) (sig []byte) {
	q := msg.Question[0]

	sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
	// The binary.BigEndian byte order is used everywhere except when the real
	// machine's endianness is needed.
	byteOrder := binary.BigEndian
	byteOrder.PutUint16(sig[0:], msg.Id)
	byteOrder.PutUint16(sig[uint16sz:], q.Qtype)
	// copy accepts a string source directly, so converting the name to []byte
	// first would only add an allocation.
	copy(sig[2*uint16sz:], q.Name)

	return sig
}
0707010000006E000081A4000000000000000000000001679A649F00000EDC000000000000000000000000000000000000003900000000dnsproxy-0.75.0/proxy/recursiondetector_internal_test.gopackage proxy

import (
	"bytes"
	"encoding/binary"
	"log/slog"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
)

func TestRecursionDetector_Check(t *testing.T) {
	const (
		recID    = 1234
		nonRecID = recID * 2
		recTTL   = time.Hour
	)

	// The zero TTL makes every message recorded through add expire right away.
	rd := newRecursionDetector(0, 2)

	question := dns.Question{
		Name:  "some.domain",
		Qtype: dns.TypeAAAA,
	}

	newMsg := func(id uint16, qs []dns.Question) (msg *dns.Msg) {
		return &dns.Msg{
			MsgHdr:   dns.MsgHdr{Id: id},
			Question: qs,
		}
	}

	// Put the recursive message into the cache manually, with an expiration
	// time far in the future.
	expire := make([]byte, uint64sz)
	binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
	rd.recentRequests.Set(msgToSignature(newMsg(recID, []dns.Question{question})), expire)

	// Add a message that is already expired.
	rd.add(newMsg(nonRecID, []dns.Question{question}))

	testCases := []struct {
		name      string
		questions []dns.Question
		id        uint16
		want      bool
	}{{
		name:      "recurrent",
		questions: []dns.Question{question},
		id:        recID,
		want:      true,
	}, {
		name:      "not_suspected",
		questions: []dns.Question{question},
		id:        recID + 1,
		want:      false,
	}, {
		name:      "expired",
		questions: []dns.Question{question},
		id:        nonRecID,
		want:      false,
	}, {
		name:      "empty",
		questions: []dns.Question{},
		id:        nonRecID,
		want:      false,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			detected := rd.check(newMsg(tc.id, tc.questions))
			assert.Equal(t, tc.want, detected)
		})
	}
}

func TestRecursionDetector_Suspect(t *testing.T) {
	// A single-entry cache is enough, since it's cleared after every case.
	rd := newRecursionDetector(0, 1)

	testCases := []struct {
		msg  *dns.Msg
		name string
		want int
	}{{
		msg: &dns.Msg{
			MsgHdr:   dns.MsgHdr{Id: 1234},
			Question: []dns.Question{{Name: "some.domain", Qtype: dns.TypeA}},
		},
		name: "simple",
		want: 1,
	}, {
		msg:  &dns.Msg{},
		name: "unencumbered",
		want: 0,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Cleanup(rd.clear)

			rd.add(tc.msg)

			got := rd.recentRequests.Stats().Count
			assert.Equal(t, tc.want, got)
		})
	}
}

// byteSink is a typed sink for benchmark results.  Storing every result in a
// package-level variable keeps the compiler from treating the benchmarked
// calls as dead code.
var byteSink []byte

// BenchmarkMsgToSignature compares msgToSignature with msgToSignatureSlow,
// which serializes a struct with [binary.Write], on the same single-question
// message.  Both sub-benchmarks report allocations.  The results recorded at
// the end were measured on the machine described there.
func BenchmarkMsgToSignature(b *testing.B) {
	const name = "some.not.very.long.host.name"

	msg := &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id: 1234,
		},
		Question: []dns.Question{{
			Name:  name,
			Qtype: dns.TypeAAAA,
		}},
	}

	b.Run("efficient", func(b *testing.B) {
		b.ReportAllocs()

		for range b.N {
			byteSink = msgToSignature(msg)
		}

		assert.NotEmpty(b, byteSink)
	})

	b.Run("inefficient", func(b *testing.B) {
		b.ReportAllocs()

		for range b.N {
			byteSink = msgToSignatureSlow(msg)
		}

		assert.NotEmpty(b, byteSink)
	})

	// goos: darwin
	// goarch: amd64
	// pkg: github.com/AdguardTeam/dnsproxy/proxy
	// cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
	// BenchmarkMsgToSignature/efficient-12		17155314	68.84 ns/op		288 B/op	1 allocs/op
	// BenchmarkMsgToSignature/inefficient-12	460803		2367 ns/op		648 B/op	6 allocs/op
}

// msgToSignatureSlow converts msg into its signature represented in bytes in
// the less efficient way: it fills a fixed-size struct and serializes it with
// [binary.Write].  It exists only as the baseline for the benchmark.
//
// Note that msgSignature puts the name first, so the resulting bytes are laid
// out differently from the ones returned by msgToSignature and can't be used
// interchangeably with them.  An error from [binary.Write] is only logged at
// the debug level, and the buffer's contents, possibly empty, are returned.
//
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg *dns.Msg) (sig []byte) {
	type msgSignature struct {
		name  [netutil.MaxDomainNameLen]byte
		id    uint16
		qtype uint16
	}

	// sig is still nil here, so the buffer starts out empty.
	b := bytes.NewBuffer(sig)
	q := msg.Question[0]
	signature := msgSignature{
		id:    msg.Id,
		qtype: q.Qtype,
	}
	copy(signature.name[:], q.Name)
	if err := binary.Write(b, binary.BigEndian, signature); err != nil {
		slog.Default().Debug("writing message signature", slogutil.KeyError, err)
	}

	return b.Bytes()
}
0707010000006F000081A4000000000000000000000001679A649F00001AE0000000000000000000000000000000000000002000000000dnsproxy-0.75.0/proxy/server.gopackage proxy

import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"net"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
)

// configureListeners opens the listeners for every configured protocol, in
// this order: plain UDP, plain TCP, TLS, HTTPS (including HTTP/3 when it's
// enabled), QUIC, and DNSCrypt.  It stops at the first failure and returns its
// error.
func (p *Proxy) configureListeners(ctx context.Context) (err error) {
	creators := []func() error{
		func() error { return p.createUDPListeners(ctx) },
		func() error { return p.createTCPListeners(ctx) },
		func() error { return p.createTLSListeners() },
		func() error { return p.createHTTPSListeners() },
		func() error { return p.createQUICListeners() },
		func() error { return p.createDNSCryptListeners() },
	}

	for _, create := range creators {
		err = create()
		if err != nil {
			return err
		}
	}

	return nil
}

// startListeners starts one serving goroutine per listener opened by
// configureListeners.  UDP, TCP, TLS, and QUIC listeners are served by the
// proxy's own packet loops, which share the requests semaphore.  HTTPS,
// HTTP/3, and DNSCrypt listeners are handed to the corresponding servers, and
// the errors returned by their serving methods are ignored.
func (p *Proxy) startListeners() {
	for _, conn := range p.udpListen {
		go p.udpPacketLoop(conn, p.requestsSema)
	}

	for _, ln := range p.tcpListen {
		go p.tcpPacketLoop(ln, ProtoTCP, p.requestsSema)
	}

	for _, ln := range p.tlsListen {
		go p.tcpPacketLoop(ln, ProtoTLS, p.requestsSema)
	}

	for _, ln := range p.httpsListen {
		go func(httpsLn net.Listener) {
			_ = p.httpsServer.Serve(httpsLn)
		}(ln)
	}

	for _, ln := range p.h3Listen {
		go func(h3Ln *quic.EarlyListener) {
			_ = p.h3Server.ServeListener(h3Ln)
		}(ln)
	}

	for _, ln := range p.quicListen {
		go p.quicPacketLoop(ln, p.requestsSema)
	}

	for _, conn := range p.dnsCryptUDPListen {
		go func(udpConn *net.UDPConn) {
			_ = p.dnsCryptServer.ServeUDP(udpConn)
		}(conn)
	}

	for _, ln := range p.dnsCryptTCPListen {
		go func(tcpLn net.Listener) {
			_ = p.dnsCryptServer.ServeTCP(tcpLn)
		}(ln)
	}
}

// handleDNSRequest runs d through the request pipeline, in this order:
//
//  1. log the request, and drop it if it's actually a response packet;
//  2. mark d as coming from a private client if its address is within the
//     private networks;
//  3. call handleBefore, and stop if it returns false;
//  4. for UDP only, apply the rate limit based on the client's IP address;
//  5. validate the request and, unless validation already produced a
//     response, resolve it with the [RequestHandler], or with [Resolve] if
//     no handler is set;
//  6. log and send the response.
//
// The only error it returns is the one from the [RequestHandler], or from
// [Resolve] if no [RequestHandler] is set.  In steps 1, 3, and 4, d is left
// without a response: step 3 follows the documentation of
// [BeforeRequestHandler], and step 4 applies when the client is ratelimited.
func (p *Proxy) handleDNSRequest(d *DNSContext) (err error) {
	p.logDNSMessage(d.Req)

	if d.Req.Response {
		p.logger.Debug("dropping incoming response packet", "addr", d.Addr)

		return nil
	}

	clientIP := d.Addr.Addr()
	d.IsPrivateClient = p.privateNets.Contains(clientIP)

	if !p.handleBefore(d) {
		return nil
	}

	// Ratelimit based on the IP address only, which protects CPU cycles and
	// outbound connections.
	//
	// TODO(e.burkov):  Investigate if written above true and move to UDP server
	// implementation?
	if d.Proto == ProtoUDP && p.isRatelimited(clientIP) {
		p.logger.Debug("ratelimited based on ip only", "addr", d.Addr)

		// Don't reply to ratelimited clients.
		return nil
	}

	if d.Res = p.validateRequest(d); d.Res == nil {
		if handler := p.RequestHandler; handler != nil {
			err = errors.Annotate(handler(p, d), "using request handler: %w")
		} else {
			err = errors.Annotate(p.Resolve(d), "using default request handler: %w")
		}
	}

	p.logDNSMessage(d.Res)
	p.respond(d)

	return err
}

// validateRequest checks d.Req and returns the response that must be sent
// back right away, or nil if the request may be resolved as usual.  The
// checks, in order, are:
//
//   - the request has exactly one question, otherwise a SERVFAIL response;
//   - the question isn't of type ANY while RefuseAny is set, otherwise the
//     response built by NewMsgNOTIMPLEMENTED;
//   - the recursion detector doesn't report the request, otherwise NXDOMAIN;
//   - the request isn't a forbidden private ARPA one, otherwise NXDOMAIN.
func (p *Proxy) validateRequest(d *DNSContext) (resp *dns.Msg) {
	req := d.Req

	if n := len(req.Question); n != 1 {
		p.logger.Debug("invalid number of questions", "req_questions_len", n)

		// TODO(e.burkov):  Probably, FORMERR would be a better choice here.
		// Check out RFC.
		return p.messages.NewMsgSERVFAIL(req)
	}

	q := req.Question[0]

	if p.RefuseAny && q.Qtype == dns.TypeANY {
		// Refuse requests of type ANY as an anti-DDoS measure.
		p.logger.Debug("refusing dns type any request")

		return p.messages.NewMsgNOTIMPLEMENTED(req)
	}

	if p.recDetector.check(req) {
		p.logger.Debug("recursion detected", "req_question", q.Name)

		return p.messages.NewMsgNXDOMAIN(req)
	}

	if d.isForbiddenARPA(p.privateNets, p.logger) {
		p.logger.Debug(
			"private arpa domain is requested",
			"addr", d.Addr,
			"arpa", q.Name,
		)

		return p.messages.NewMsgNXDOMAIN(req)
	}

	return nil
}

// isForbiddenARPA returns true if dctx contains a PTR, SOA, or NS request for
// an address within privateNets and the client itself is not private.
// Whenever such a request targets a private address, the requested prefix is
// also stored in [DNSContext.RequestedPrivateRDNS] for future use, whatever
// the result.  A name that can't be parsed as a reversed address is logged
// with l at the debug level and is not considered forbidden.
func (dctx *DNSContext) isForbiddenARPA(privateNets netutil.SubnetSet, l *slog.Logger) (ok bool) {
	q := dctx.Req.Question[0]

	// TODO(e.burkov):  Reconsider the list of types involved to private
	// address space.  Perhaps, use the logic for any type.  See
	// https://www.rfc-editor.org/rfc/rfc6761.html#section-6.1.
	if q.Qtype != dns.TypePTR && q.Qtype != dns.TypeSOA && q.Qtype != dns.TypeNS {
		return false
	}

	pref, err := netutil.ExtractReversedAddr(q.Name)
	if err != nil {
		l.Debug("parsing reversed subnet", slogutil.KeyError, err)

		return false
	}

	if !privateNets.Contains(pref.Addr()) {
		return false
	}

	dctx.RequestedPrivateRDNS = pref

	return !dctx.IsPrivateClient
}

// respond writes d.Res back to the client using the writer for d.Proto, after
// setting a write deadline of defaultTimeout on d.Conn when there is one.
// Handling of a nil d.Res is left to the protocol-specific writer; for
// example, the HTTPS writer replies with 500 Internal Server Error, while the
// DNSCrypt writer writes nothing.  Errors are logged, not returned.
func (p *Proxy) respond(d *DNSContext) {
	// d.Conn can be nil in the case of a DoH request.
	if conn := d.Conn; conn != nil {
		_ = conn.SetWriteDeadline(time.Now().Add(defaultTimeout))
	}

	var err error

	switch proto := d.Proto; proto {
	case ProtoUDP:
		err = p.respondUDP(d)
	case ProtoTCP, ProtoTLS:
		err = p.respondTCP(d)
	case ProtoHTTPS:
		err = p.respondHTTPS(d)
	case ProtoQUIC:
		err = p.respondQUIC(d)
	case ProtoDNSCrypt:
		err = p.respondDNSCrypt(d)
	default:
		err = fmt.Errorf("SHOULD NOT HAPPEN - unknown protocol: %s", proto)
	}

	if err == nil {
		return
	}

	logWithNonCrit(err, "responding request", d.Proto, p.logger)
}

// setMinMaxTTL replaces the TTL of every record in r's answer section with
// the value returned by respectTTLOverrides for the proxy's CacheMinTTL and
// CacheMaxTTL settings, logging every TTL that changes.
func (p *Proxy) setMinMaxTTL(r *dns.Msg) {
	for _, rr := range r.Answer {
		hdr := rr.Header()

		oldTTL := hdr.Ttl
		newTTL := respectTTLOverrides(oldTTL, p.CacheMinTTL, p.CacheMaxTTL)
		if newTTL == oldTTL {
			continue
		}

		p.logger.Debug("ttl overwritten", "old", oldTTL, "new", newTTL)
		hdr.Ttl = newTTL
	}
}

// logDNSMessage logs the given DNS message.
func (p *Proxy) logDNSMessage(m *dns.Msg) {
	if m == nil {
		return
	}

	var msg string
	if m.Response {
		msg = "out"
	} else {
		msg = "in"
	}

	slogutil.PrintLines(context.TODO(), p.logger, slog.LevelDebug, msg, m.String())
}

// logWithNonCrit logs the error on the appropriate level depending on whether
// err is a critical error or not.
func logWithNonCrit(err error, msg string, proto Proto, l *slog.Logger) {
	if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || isEPIPE(err) {
		l.Debug(
			"connection is closed",
			"proto", proto,
			"details", msg,
			slogutil.KeyError, err,
		)
	} else if netErr := net.Error(nil); errors.As(err, &netErr) && netErr.Timeout() {
		l.Debug(
			"connection timed out",
			"proto", proto,
			"details", msg,
			slogutil.KeyError, err,
		)
	} else {
		l.Error(msg, "proto", proto, slogutil.KeyError, err)
	}
}
07070100000070000081A4000000000000000000000001679A649F00000A9B000000000000000000000000000000000000002900000000dnsproxy-0.75.0/proxy/server_dnscrypt.gopackage proxy

import (
	"context"
	"fmt"
	"net"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/syncutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
)

// createDNSCryptListeners creates the DNSCrypt server.  It opens a UDP
// listener for each address in DNSCryptUDPListenAddr and a TCP listener for
// each address in DNSCryptTCPListenAddr.  It does nothing if neither list
// contains any addresses.  It returns an error if the resolver certificate or
// the provider name is missing, or if a listener can't be opened.
func (p *Proxy) createDNSCryptListeners() (err error) {
	udpAddrs, tcpAddrs := p.DNSCryptUDPListenAddr, p.DNSCryptTCPListenAddr
	if len(udpAddrs)+len(tcpAddrs) == 0 {
		// No DNSCrypt listen addresses are specified, so there is nothing to
		// do.
		return nil
	}

	if p.DNSCryptResolverCert == nil || p.DNSCryptProviderName == "" {
		return errors.Error("invalid dnscrypt configuration: no certificate or provider name")
	}

	p.logger.Info("initializing dnscrypt", "provider", p.DNSCryptProviderName)

	handler := &dnsCryptHandler{
		proxy:   p,
		reqSema: p.requestsSema,
	}
	p.dnsCryptServer = &dnscrypt.Server{
		ProviderName: p.DNSCryptProviderName,
		ResolverCert: p.DNSCryptResolverCert,
		Handler:      handler,
	}

	for _, addr := range udpAddrs {
		p.logger.Info("creating dnscrypt udp listener")

		conn, listenErr := net.ListenUDP(bootstrap.NetworkUDP, addr)
		if listenErr != nil {
			return fmt.Errorf("listening to dnscrypt udp socket: %w", listenErr)
		}

		p.dnsCryptUDPListen = append(p.dnsCryptUDPListen, conn)
		p.logger.Info("listening for dnscrypt messages on udp", "addr", conn.LocalAddr())
	}

	for _, addr := range tcpAddrs {
		p.logger.Info("creating a dnscrypt tcp listener")

		ln, listenErr := net.ListenTCP(bootstrap.NetworkTCP, addr)
		if listenErr != nil {
			return fmt.Errorf("listening to dnscrypt tcp socket: %w", listenErr)
		}

		p.dnsCryptTCPListen = append(p.dnsCryptTCPListen, ln)
		p.logger.Info("listening for dnscrypt messages on tcp", "addr", ln.Addr())
	}

	return nil
}

// dnsCryptHandler is the [dnscrypt.Handler] registered with the proxy's
// DNSCrypt server.  It hands every query to the proxy's common request
// handling.
type dnsCryptHandler struct {
	// proxy is the proxy that handles the queries.
	proxy *Proxy

	// reqSema is held for the whole duration of every query.
	reqSema syncutil.Semaphore
}

// type check
var _ dnscrypt.Handler = (*dnsCryptHandler)(nil)

// ServeDNS implements the [dnscrypt.Handler] interface for *dnsCryptHandler.
// It wraps req into a DNSCrypt request context that responds through rw,
// waits for the request semaphore, and passes the request to the proxy.  It
// returns an error if the semaphore can't be acquired, or the error of the
// request handling otherwise.
func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) (err error) {
	clientAddr := netutil.NetAddrToAddrPort(rw.RemoteAddr())
	dctx := h.proxy.newDNSContext(ProtoDNSCrypt, req, clientAddr)
	dctx.DNSCryptResponseWriter = rw

	// TODO(d.kolyshev): Pass and use context from above.
	if err = h.reqSema.Acquire(context.Background()); err != nil {
		return fmt.Errorf("dnsproxy: dnscrypt: acquiring semaphore: %w", err)
	}
	defer h.reqSema.Release()

	return h.proxy.handleDNSRequest(dctx)
}

// respondDNSCrypt writes d.Res to the DNSCrypt client through
// d.DNSCryptResponseWriter.  A nil d.Res is not an error: nothing is written,
// and the request is dropped.
func (p *Proxy) respondDNSCrypt(d *DNSContext) error {
	if res := d.Res; res != nil {
		return d.DNSCryptResponseWriter.WriteMsg(res)
	}

	return nil
}
07070100000071000081A4000000000000000000000001679A649F00000A4C000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/proxy/server_dnscrypt_test.gopackage proxy

import (
	"context"
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/ameshkov/dnsstamps"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMain runs the package's tests with log output discarded by
// [testutil.DiscardLogOutput].
//
// TODO(d.kolyshev): Remove this after migrating dnscrypt to slog.
func TestMain(m *testing.M) {
	testutil.DiscardLogOutput(m)
}

// getFreePort returns a TCP port on 127.0.0.1 that was free when it was
// checked.  It binds to port 0 so that the OS picks an unused port, releases
// the port right away, and waits briefly before returning.  Another process
// may still take the port in between.  getFreePort panics with a descriptive
// error if no listener can be opened at all, since it has no [testing.TB] to
// report the failure to.
func getFreePort() uint {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(fmt.Errorf("getting free port: %w", err))
	}

	port := uint(l.Addr().(*net.TCPAddr).Port)

	// Stop listening immediately, so that the caller can bind to the port.
	_ = l.Close()

	// Sleep for 100ms, which may be necessary on Windows.
	time.Sleep(100 * time.Millisecond)

	return port
}

// createTestDNSCryptProxy returns a new proxy that is not yet started, along
// with the resolver configuration used to create its certificate.  The proxy
// serves DNSCrypt over both UDP and TCP on the same free port of listenIP.
// The test stops immediately if the resolver configuration or the certificate
// can't be created, since the proxy can't be configured without them.
func createTestDNSCryptProxy(t *testing.T) (*Proxy, dnscrypt.ResolverConfig) {
	t.Helper()

	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	cert, err := rc.CreateCert()
	require.NoError(t, err)

	port := getFreePort()
	p := mustNew(t, &Config{
		Logger: slogutil.NewDiscardLogger(),
		DNSCryptUDPListenAddr: []*net.UDPAddr{{
			Port: int(port), IP: net.ParseIP(listenIP),
		}},
		DNSCryptTCPListenAddr: []*net.TCPAddr{{
			Port: int(port), IP: net.ParseIP(listenIP),
		}},
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		EnableEDNSClientSubnet: true,
		CacheEnabled:           true,
		CacheMinTTL:            20,
		CacheMaxTTL:            40,
		DNSCryptProviderName:   rc.ProviderName,
		DNSCryptResolverCert:   cert,
	})

	return p, rc
}

// TestDNSCryptProxy starts a DNSCrypt proxy and checks that it answers a test
// message over both UDP and TCP.  Each transport is checked in its own
// subtest, so a failure of one doesn't hide the result of the other.
func TestDNSCryptProxy(t *testing.T) {
	// Prepare the proxy server.
	dnsProxy, rc := createTestDNSCryptProxy(t)

	// Start listening.
	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Generate a DNS stamp.  A failure is reported without stopping the test,
	// because checkDNSCryptProxy reports its own errors for each transport.
	addr := fmt.Sprintf("%s:%d", listenIP, dnsProxy.Addr(ProtoDNSCrypt).(*net.UDPAddr).Port)
	stamp, err := rc.CreateStamp(addr)
	assert.NoError(t, err)

	// Test the DNSCrypt proxy on both UDP and TCP.
	for _, proto := range []string{"udp", "tcp"} {
		t.Run(proto, func(t *testing.T) {
			checkDNSCryptProxy(t, proto, stamp)
		})
	}
}

// checkDNSCryptProxy sends a test message over proto ("udp" or "tcp") to the
// DNSCrypt server described by stamp and requires a valid response.  The test
// stops as soon as fetching the certificate or the exchange fails, since each
// later step depends on the previous one's result.
func checkDNSCryptProxy(t *testing.T, proto string, stamp dnsstamps.ServerStamp) {
	t.Helper()

	// Create a DNSCrypt client.
	c := &dnscrypt.Client{
		Timeout: defaultTimeout,
		Net:     proto,
	}

	// Fetch the server certificate.
	ri, err := c.DialStamp(stamp)
	require.NoError(t, err)

	// Send the test message.
	msg := newTestMessage()
	reply, err := c.Exchange(msg, ri)
	require.NoError(t, err)

	requireResponse(t, msg, reply)
}
07070100000072000081A4000000000000000000000001679A649F000021C5000000000000000000000000000000000000002600000000dnsproxy-0.75.0/proxy/server_https.gopackage proxy

import (
	"context"
	"crypto/subtle"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"strings"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/httphdr"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"golang.org/x/net/http2"
)

// listenHTTP opens a TCP listener on addr and wraps it into a TLS listener
// for the HTTP/1.1 and HTTP/2 server.  The TLS listener uses a clone of the
// proxy's TLS configuration that advertises only the "h2" and "http/1.1" ALPN
// protocols.  It returns the address the TCP listener actually listens on,
// which is useful when addr has port 0.
func (p *Proxy) listenHTTP(addr *net.TCPAddr) (laddr *net.TCPAddr, err error) {
	tcpLn, err := net.ListenTCP(bootstrap.NetworkTCP, addr)
	if err != nil {
		return nil, fmt.Errorf("tcp listener: %w", err)
	}

	p.logger.Info("listening to https", "addr", tcpLn.Addr())

	conf := p.TLSConfig.Clone()
	conf.NextProtos = []string{http2.NextProtoTLS, "http/1.1"}

	p.httpsListen = append(p.httpsListen, tls.NewListener(tcpLn, conf))

	laddr = tcpLn.Addr().(*net.TCPAddr)

	return laddr, nil
}

// listenH3 opens a QUIC listener on addr for the HTTP/3 server and adds it to
// the proxy's HTTP/3 listeners.  The listener uses a clone of the proxy's TLS
// configuration that advertises only the "h3" ALPN protocol.
func (p *Proxy) listenH3(addr *net.UDPAddr) (err error) {
	conf := p.TLSConfig.Clone()
	conf.NextProtos = []string{"h3"}

	ln, err := quic.ListenAddrEarly(addr.String(), conf, newServerQUICConfig())
	if err != nil {
		return fmt.Errorf("quic listener: %w", err)
	}

	p.logger.Info("listening to h3", "addr", ln.Addr())
	p.h3Listen = append(p.h3Listen, ln)

	return nil
}

// createHTTPSListeners creates the HTTPS server and, when HTTP3 is enabled,
// the HTTP/3 server.  For every address in HTTPSListenAddr it opens a TLS
// listener and, with HTTP/3, a QUIC listener on the same IP and port.  It
// stops at the first listener that can't be opened.
func (p *Proxy) createHTTPSListeners() (err error) {
	p.httpsServer = &http.Server{
		Handler:           p,
		ReadHeaderTimeout: defaultTimeout,
		WriteTimeout:      defaultTimeout,
	}

	if p.HTTP3 {
		p.h3Server = &http3.Server{Handler: p}
	}

	for _, addr := range p.HTTPSListenAddr {
		p.logger.Info("creating an https server")

		tcpAddr, listenErr := p.listenHTTP(addr)
		if listenErr != nil {
			return fmt.Errorf("failed to start HTTPS server on %s: %w", addr, listenErr)
		}

		if !p.HTTP3 {
			continue
		}

		// The HTTP/3 server uses the same IP:port pair as the HTTP/2 one, only
		// over UDP.
		udpAddr := &net.UDPAddr{IP: tcpAddr.IP, Port: tcpAddr.Port}
		if err = p.listenH3(udpAddr); err != nil {
			return fmt.Errorf("failed to start HTTP/3 server on %s: %w", udpAddr, err)
		}
	}

	return nil
}

// newDoHReq returns new DNS request parsed from the given HTTP request.  In
// case of invalid request returns nil and the suitable status code for an HTTP
// error response.  l must not be nil.
//
// For GET requests, the message is taken from the "dns" query parameter,
// which must be base64url-encoded without padding.  For POST requests, the
// body must have the "application/dns-message" content type.  It is read up to
// [dns.MaxMsgSize] bytes, the largest possible DNS message, and a longer body
// is rejected with [http.StatusRequestEntityTooLarge] instead of being
// buffered in full, since it comes from an untrusted client.
func newDoHReq(r *http.Request, l *slog.Logger) (req *dns.Msg, statusCode int) {
	var buf []byte
	var err error

	switch r.Method {
	case http.MethodGet:
		dnsParam := r.URL.Query().Get("dns")
		buf, err = base64.RawURLEncoding.DecodeString(dnsParam)
		if len(buf) == 0 || err != nil {
			l.Debug(
				"parsing dns request from http get param",
				"param_name", dnsParam,
				slogutil.KeyError, err,
			)

			return nil, http.StatusBadRequest
		}
	case http.MethodPost:
		contentType := r.Header.Get(httphdr.ContentType)
		if contentType != "application/dns-message" {
			l.Debug("unsupported media type", "content_type", contentType)

			return nil, http.StatusUnsupportedMediaType
		}

		defer slogutil.CloseAndLog(context.TODO(), l, r.Body, slog.LevelDebug)

		// Read at most one byte more than the limit, so that an oversized body
		// is detected without reading all of it into memory.
		buf, err = io.ReadAll(io.LimitReader(r.Body, dns.MaxMsgSize+1))
		if err != nil {
			l.Debug("reading http request body", slogutil.KeyError, err)

			return nil, http.StatusBadRequest
		}

		if len(buf) > dns.MaxMsgSize {
			l.Debug("http request body too large", "max_size", dns.MaxMsgSize)

			return nil, http.StatusRequestEntityTooLarge
		}
	default:
		l.Debug("bad http method", "method", r.Method)

		return nil, http.StatusMethodNotAllowed
	}

	req = &dns.Msg{}
	if err = req.Unpack(buf); err != nil {
		l.Debug("unpacking http msg", slogutil.KeyError, err)

		return nil, http.StatusBadRequest
	}

	return req, http.StatusOK
}

// ServeHTTP implements the [http.Handler] interface for *Proxy and handles
// DoH queries.  It determines the client's address, taking the trusted
// proxies into account.  It then checks basic authentication if it's
// configured, parses the DNS request from r, and passes the request to the
// common request handling.
//
// When r can't be turned into a DNS request, it replies with the status code
// chosen by newDoHReq, for example:
//
//   - http.StatusBadRequest if there is no DNS request data,
//   - http.StatusUnsupportedMediaType if request content type is not
//     "application/dns-message",
//   - http.StatusMethodNotAllowed if request method is not GET or POST.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	l := p.logger
	l.Debug("incoming https request", "url", r.URL)

	raddr, prx, err := remoteAddr(r, l)
	if err != nil {
		l.Debug("getting real ip", slogutil.KeyError, err)
	}

	if !p.checkBasicAuth(w, r, raddr) {
		return
	}

	req, statusCode := newDoHReq(r, l)
	if req == nil {
		http.Error(w, http.StatusText(statusCode), statusCode)

		return
	}

	if prx.IsValid() {
		l.Debug("request came from proxy server", "addr", prx)

		if !p.TrustedProxies.Contains(prx.Addr()) {
			l.Debug("proxy is not trusted, using original remote addr", "addr", prx)

			// The address parsed from the headers can't be trusted, so use the
			// address of the proxy itself as the remote address.
			//
			// TODO(e.burkov): Do not parse headers in this case.
			raddr = prx
		}
	}

	dctx := p.newDNSContext(ProtoHTTPS, req, raddr)
	dctx.HTTPRequest = r
	dctx.HTTPResponseWriter = w

	if err = p.handleDNSRequest(dctx); err != nil {
		l.Debug("handling dns request", "proto", dctx.Proto, slogutil.KeyError, err)
	}
}

// checkBasicAuth reports whether r may be handled.  When no Userinfo is
// configured, every request is allowed.  Otherwise, the request's basic
// authorization credentials must match it.  If they don't, the failure is
// logged together with raddr, and w receives a 401 response with a
// WWW-Authenticate challenge.
func (p *Proxy) checkBasicAuth(
	w http.ResponseWriter,
	r *http.Request,
	raddr netip.AddrPort,
) (shouldHandle bool) {
	required := p.Config.Userinfo
	if required == nil {
		return true
	}

	user, pass, _ := r.BasicAuth()
	if shouldHandle = matchesUserinfo(required, user, pass); shouldHandle {
		return true
	}

	p.logger.Error("basic auth failed", "user", user, "raddr", raddr)

	w.Header().Set(httphdr.WWWAuthenticate, `Basic realm="DNS", charset="UTF-8"`)
	http.Error(w, "Authorization required", http.StatusUnauthorized)

	return false
}

// matchesUserinfo returns true if user and pass match the username and
// password of userinfo; a missing password in userinfo is treated as empty.
// userinfo must not be nil.
//
// user and pass come from the client, so they are compared with
// [subtle.ConstantTimeCompare], which keeps the time taken from revealing how
// many leading bytes matched.  Both comparisons are always made, so the timing
// also doesn't reveal which of the two was wrong.  Only the lengths of the
// values may still leak.
func matchesUserinfo(userinfo *url.Userinfo, user, pass string) (ok bool) {
	requiredPassword, _ := userinfo.Password()

	userOK := subtle.ConstantTimeCompare([]byte(user), []byte(userinfo.Username()))
	passOK := subtle.ConstantTimeCompare([]byte(pass), []byte(requiredPassword))

	return userOK&passOK == 1
}

// respondHTTPS writes d.Res to the DoH client as an "application/dns-message"
// body, also setting the Server header when HTTPSServerName is configured.
// A nil d.Res is reported to the client as 500 Internal Server Error without
// an error.  A failure to pack d.Res is also reported as 500 and returned as
// an error.
func (p *Proxy) respondHTTPS(d *DNSContext) (err error) {
	w := d.HTTPResponseWriter

	if d.Res == nil {
		// Indicate the response's absence via a http.StatusInternalServerError.
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return nil
	}

	packed, err := d.Res.Pack()
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return fmt.Errorf("packing message: %w", err)
	}

	hdr := w.Header()
	if srvName := p.Config.HTTPSServerName; srvName != "" {
		hdr.Set(httphdr.Server, srvName)
	}
	hdr.Set(httphdr.ContentType, "application/dns-message")

	_, err = w.Write(packed)

	return err
}

// realIPFromHdrs extracts the actual client's IP address from the first
// suitable r's header.  It returns an error if r doesn't contain any
// information about real client's IP address.  Current headers priority is:
//
//  1. [httphdr.CFConnectingIP]
//  2. [httphdr.TrueClientIP]
//  3. [httphdr.XRealIP]
//  4. [httphdr.XForwardedFor]
//
// Surrounding whitespace is ignored.  From [httphdr.XForwardedFor], only the
// leftmost entry is used.
func realIPFromHdrs(r *http.Request) (realIP netip.Addr, err error) {
	singleAddrHdrs := [...]string{
		httphdr.CFConnectingIP,
		httphdr.TrueClientIP,
		httphdr.XRealIP,
	}

	for _, name := range singleAddrHdrs {
		if realIP, err = netip.ParseAddr(strings.TrimSpace(r.Header.Get(name))); err == nil {
			return realIP, nil
		}
	}

	xff := r.Header.Get(httphdr.XForwardedFor)
	if i := strings.IndexByte(xff, ','); i > 0 {
		xff = xff[:i]
	}

	return netip.ParseAddr(strings.TrimSpace(xff))
}

// remoteAddr returns the real client's address and the address of the latest
// proxy server, if any.  The peer address of r must parse as an IP:port pair,
// otherwise its parsing error is returned.  If r's headers contain no usable
// client address, the peer is returned as the client and prx is empty.
// Otherwise, the address from the headers, with port 0, is returned as the
// client, and the peer is returned as prx.
func remoteAddr(r *http.Request, l *slog.Logger) (addr, prx netip.AddrPort, err error) {
	peer, err := netip.ParseAddrPort(r.RemoteAddr)
	if err != nil {
		return netip.AddrPort{}, netip.AddrPort{}, err
	}

	realIP, hdrErr := realIPFromHdrs(r)
	if hdrErr != nil {
		l.Debug("getting ip address from http request", slogutil.KeyError, hdrErr)

		return peer, netip.AddrPort{}, nil
	}

	l.Debug("using ip address from http request", "addr", realIP)

	// TODO(a.garipov): Add port if we can get it from headers like X-Real-Port,
	// X-Forwarded-Port, etc.
	return netip.AddrPortFrom(realIP, 0), peer, nil
}
07070100000073000081A4000000000000000000000001679A649F00002F4C000000000000000000000000000000000000002B00000000dnsproxy-0.75.0/proxy/server_https_test.gopackage proxy

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"strings"
	"testing"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestHttpsProxy(t *testing.T) {
	// testProxy starts a proxy that serves DoH, over HTTP/3 if useHTTP3 is
	// true, and requires it to answer a test message.
	testProxy := func(t *testing.T, useHTTP3 bool) {
		tlsConf, caPem := newTLSConfig(t)
		dnsProxy := mustNew(t, &Config{
			Logger:                 slogutil.NewDiscardLogger(),
			TLSListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
			HTTPSListenAddr:        []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
			QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
			TLSConfig:              tlsConf,
			UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
			TrustedProxies:         defaultTrustedProxies,
			RatelimitSubnetLenIPv4: 24,
			RatelimitSubnetLenIPv6: 64,
			HTTP3:                  useHTTP3,
		})

		// Run the proxy.
		ctx := context.Background()
		err := dnsProxy.Start(ctx)
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

		client := createTestHTTPClient(dnsProxy, caPem, useHTTP3)

		// Send a test message and check that the response matches it.
		msg := newTestMessage()
		resp := sendTestDoHMessage(t, client, msg, nil)
		requireResponse(t, msg, resp)
	}

	t.Run("https_proxy", func(t *testing.T) { testProxy(t, false) })
	t.Run("h3_proxy", func(t *testing.T) { testProxy(t, true) })
}

func TestProxy_trustedProxies(t *testing.T) {
	var (
		clientAddr = netip.MustParseAddr("1.2.3.4")
		proxyAddr  = netip.MustParseAddr("127.0.0.1")
	)

	// doRequest starts a DoH proxy that trusts only trustedAddr.  It sends the
	// proxy a request whose X-Forwarded-For header lists clientAddr and
	// proxyAddr, and requires the request handler to see wantClientIP as the
	// client's address.
	doRequest := func(t *testing.T, trustedAddr, wantClientIP netip.Addr) {
		tlsConf, caPem := newTLSConfig(t)
		srv := mustNew(t, &Config{
			Logger:                 slogutil.NewDiscardLogger(),
			TLSListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
			HTTPSListenAddr:        []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
			QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
			TLSConfig:              tlsConf,
			UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
			TrustedProxies:         defaultTrustedProxies,
			RatelimitSubnetLenIPv4: 24,
			RatelimitSubnetLenIPv6: 64,
		})

		// Record the client address the proxy ends up using.
		var gotAddr netip.Addr
		srv.RequestHandler = func(_ *Proxy, d *DNSContext) (err error) {
			gotAddr = d.Addr.Addr()

			return srv.Resolve(d)
		}

		client := createTestHTTPClient(srv, caPem, false)
		msg := newTestMessage()

		srv.TrustedProxies = netip.PrefixFrom(trustedAddr, trustedAddr.BitLen())

		// Start listening.
		ctx := context.Background()
		err := srv.Start(ctx)
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, func() (err error) { return srv.Shutdown(ctx) })

		xff := strings.Join([]string{clientAddr.String(), proxyAddr.String()}, ",")
		resp := sendTestDoHMessage(t, client, msg, map[string]string{
			"X-Forwarded-For": xff,
		})
		requireResponse(t, msg, resp)

		require.Equal(t, wantClientIP, gotAddr)
	}

	t.Run("success", func(t *testing.T) {
		doRequest(t, proxyAddr, clientAddr)
	})

	t.Run("not_in_trusted", func(t *testing.T) {
		doRequest(t, netip.MustParseAddr("127.0.0.2"), proxyAddr)
	})
}

func TestAddrsFromRequest(t *testing.T) {
	var (
		theIP     = netip.AddrFrom4([4]byte{1, 2, 3, 4})
		anotherIP = netip.AddrFrom4([4]byte{1, 2, 3, 5})

		theIPStr     = theIP.String()
		anotherIPStr = anotherIP.String()
	)

	testCases := []struct {
		name    string
		hdrs    map[string]string
		wantIP  netip.Addr
		wantErr string
	}{{
		name: "cf-connecting-ip",
		hdrs: map[string]string{
			"CF-Connecting-IP": theIPStr,
		},
		wantIP:  theIP,
		wantErr: "",
	}, {
		name: "true-client-ip",
		hdrs: map[string]string{
			"True-Client-IP": theIPStr,
		},
		wantIP:  theIP,
		wantErr: "",
	}, {
		name: "x-real-ip",
		hdrs: map[string]string{
			"X-Real-IP": theIPStr,
		},
		wantIP:  theIP,
		wantErr: "",
	}, {
		name: "no_any",
		hdrs: map[string]string{
			"CF-Connecting-IP": "invalid",
			"True-Client-IP":   "invalid",
			"X-Real-IP":        "invalid",
		},
		wantIP:  netip.Addr{},
		wantErr: `ParseAddr(""): unable to parse IP`,
	}, {
		name: "priority",
		hdrs: map[string]string{
			"X-Forwarded-For":  strings.Join([]string{anotherIPStr, theIPStr}, ","),
			"True-Client-IP":   anotherIPStr,
			"X-Real-IP":        anotherIPStr,
			"CF-Connecting-IP": theIPStr,
		},
		wantIP:  theIP,
		wantErr: "",
	}, {
		name: "x-forwarded-for_simple",
		hdrs: map[string]string{
			"X-Forwarded-For": strings.Join([]string{anotherIPStr, theIPStr}, ","),
		},
		wantIP:  anotherIP,
		wantErr: "",
	}, {
		name: "x-forwarded-for_single",
		hdrs: map[string]string{
			"X-Forwarded-For": theIPStr,
		},
		wantIP:  theIP,
		wantErr: "",
	}, {
		name: "x-forwarded-for_invalid_proxy",
		hdrs: map[string]string{
			"X-Forwarded-For": strings.Join([]string{theIPStr, "invalid"}, ","),
		},
		wantIP:  theIP,
		wantErr: "",
	}, {
		name: "x-forwarded-for_empty",
		hdrs: map[string]string{
			"X-Forwarded-For": "",
		},
		wantIP:  netip.Addr{},
		wantErr: `ParseAddr(""): unable to parse IP`,
	}, {
		name: "x-forwarded-for_redundant_spaces",
		hdrs: map[string]string{
			"X-Forwarded-For": "  " + theIPStr + "   ,\t" + anotherIPStr,
		},
		wantIP:  theIP,
		wantErr: "",
	}, {
		name: "cf-connecting-ip_redundant_spaces",
		hdrs: map[string]string{
			"CF-Connecting-IP": "  " + theIPStr + "\t",
		},
		wantIP:  theIP,
		wantErr: "",
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Build the request inside the subtest, so that every case has its
			// own request and error variables.
			r, err := http.NewRequest(http.MethodGet, "localhost", nil)
			require.NoError(t, err)

			for h, v := range tc.hdrs {
				r.Header.Set(h, v)
			}

			ip, err := realIPFromHdrs(r)
			testutil.AssertErrorMsg(t, tc.wantErr, err)

			assert.Equal(t, tc.wantIP, ip)
		})
	}
}

func TestRemoteAddr(t *testing.T) {
	const thePort = 4321

	var (
		theIP     = netip.AddrFrom4([4]byte{1, 2, 3, 4})
		anotherIP = netip.AddrFrom4([4]byte{1, 2, 3, 5})
		thirdIP   = netip.AddrFrom4([4]byte{1, 2, 3, 6})

		theIPStr     = theIP.String()
		anotherIPStr = anotherIP.String()
		thirdIPStr   = thirdIP.String()
	)

	rAddr := netip.AddrPortFrom(theIP, thePort)

	testCases := []struct {
		name       string
		remoteAddr string
		hdrs       map[string]string
		wantErr    string
		wantIP     netip.AddrPort
		wantProxy  netip.AddrPort
	}{{
		name:       "no_proxy",
		remoteAddr: rAddr.String(),
		hdrs:       nil,
		wantErr:    "",
		wantIP:     netip.AddrPortFrom(theIP, thePort),
		wantProxy:  netip.AddrPort{},
	}, {
		name:       "proxied_with_cloudflare",
		remoteAddr: rAddr.String(),
		hdrs: map[string]string{
			"CF-Connecting-IP": anotherIPStr,
		},
		wantErr:   "",
		wantIP:    netip.AddrPortFrom(anotherIP, 0),
		wantProxy: netip.AddrPortFrom(theIP, thePort),
	}, {
		name:       "proxied_once",
		remoteAddr: rAddr.String(),
		hdrs: map[string]string{
			"X-Forwarded-For": anotherIPStr,
		},
		wantErr:   "",
		wantIP:    netip.AddrPortFrom(anotherIP, 0),
		wantProxy: netip.AddrPortFrom(theIP, thePort),
	}, {
		name:       "proxied_multiple",
		remoteAddr: rAddr.String(),
		hdrs: map[string]string{
			"X-Forwarded-For": strings.Join([]string{anotherIPStr, thirdIPStr}, ","),
		},
		wantErr:   "",
		wantIP:    netip.AddrPortFrom(anotherIP, 0),
		wantProxy: netip.AddrPortFrom(theIP, thePort),
	}, {
		name:       "no_port",
		remoteAddr: theIPStr,
		hdrs:       nil,
		wantErr:    "not an ip:port",
		wantIP:     netip.AddrPort{},
		wantProxy:  netip.AddrPort{},
	}, {
		name:       "bad_port",
		remoteAddr: theIPStr + ":notport",
		hdrs:       nil,
		wantErr:    `invalid port "notport" parsing "1.2.3.4:notport"`,
		wantIP:     netip.AddrPort{},
		wantProxy:  netip.AddrPort{},
	}, {
		name:       "bad_host",
		remoteAddr: "host:1",
		hdrs:       nil,
		wantErr:    `ParseAddr("host"): unable to parse IP`,
		wantIP:     netip.AddrPort{},
		wantProxy:  netip.AddrPort{},
	}, {
		name:       "bad_proxied_host",
		remoteAddr: "host:1",
		hdrs: map[string]string{
			"CF-Connecting-IP": theIPStr,
		},
		wantErr:   `ParseAddr("host"): unable to parse IP`,
		wantIP:    netip.AddrPort{},
		wantProxy: netip.AddrPort{},
	}}

	logger := slogutil.NewDiscardLogger()

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Build a fresh request per case so nothing is shared between
			// subtests.
			r, err := http.NewRequest(http.MethodGet, "localhost", nil)
			require.NoError(t, err)

			r.RemoteAddr = tc.remoteAddr
			for name, val := range tc.hdrs {
				r.Header.Set(name, val)
			}

			gotAddr, gotProxy, err := remoteAddr(r, logger)
			if tc.wantErr != "" {
				testutil.AssertErrorMsg(t, tc.wantErr, err)

				return
			}

			require.NoError(t, err)
			assert.Equal(t, tc.wantIP, gotAddr)
			assert.Equal(t, tc.wantProxy, gotProxy)
		})
	}
}

// sendTestDoHMessage packs m, sends it as a DNS-over-HTTPS GET request for
// host tlsServerName using client, and returns the unpacked DNS response.
// hdrs are set after the default Content-Type and Accept headers, so they may
// override them.  When client uses an *http3.Transport, the request uses
// http3.MethodGet0RTT to force 0-RTT.  The test fails if any step fails or if
// the response protocol is older than HTTP/2.
func sendTestDoHMessage(
	t *testing.T,
	client *http.Client,
	m *dns.Msg,
	hdrs map[string]string,
) (resp *dns.Msg) {
	wire, err := m.Pack()
	require.NoError(t, err)

	// Force 0-RTT when the client speaks HTTP/3.
	method := http.MethodGet
	if _, isH3 := client.Transport.(*http3.Transport); isH3 {
		method = http3.MethodGet0RTT
	}

	target := url.URL{
		Scheme:   "https",
		Host:     tlsServerName,
		Path:     "/dns-query",
		RawQuery: fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(wire)),
	}

	req, err := http.NewRequest(method, target.String(), nil)
	require.NoError(t, err)

	req.Header.Set("Content-Type", "application/dns-message")
	req.Header.Set("Accept", "application/dns-message")
	for name, val := range hdrs {
		req.Header.Set(name, val)
	}

	httpResp, err := client.Do(req) // nolint:bodyclose
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, httpResp.Body.Close)

	require.True(t, httpResp.ProtoAtLeast(2, 0), "the proto is too old: %s", httpResp.Proto)

	body, err := io.ReadAll(httpResp.Body)
	require.NoError(t, err)

	resp = new(dns.Msg)
	require.NoError(t, resp.Unpack(body))

	return resp
}

// createTestHTTPClient returns an *http.Client whose transport trusts the CA
// certificates in caPem and always dials dnsProxy's HTTPS listener, whatever
// the request URL's host is.  If http3Enabled is true, the transport is an
// *http3.Transport offering ALPN "h3"; otherwise it is an *http.Transport
// offering "h2" and "http/1.1" with HTTP/2 forced on.  Both the client and the
// TCP dialer use defaultTimeout.
func createTestHTTPClient(dnsProxy *Proxy, caPem []byte, http3Enabled bool) (client *http.Client) {
	// Trust the test CA so that the server certificate validates.
	pool := x509.NewCertPool()
	pool.AppendCertsFromPEM(caPem)

	tlsConf := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    pool,
	}

	// serverAddr is evaluated at dial time and routes every connection to the
	// proxy's DNS-over-HTTPS address.
	serverAddr := func() (addr string) {
		return dnsProxy.Addr(ProtoHTTPS).String()
	}

	client = &http.Client{
		Timeout: defaultTimeout,
	}

	if !http3Enabled {
		tlsConf.NextProtos = []string{"h2", "http/1.1"}

		dialer := &net.Dialer{
			Timeout: defaultTimeout,
		}
		client.Transport = &http.Transport{
			TLSClientConfig:    tlsConf,
			DisableCompression: true,
			DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, network, serverAddr())
			},
			ForceAttemptHTTP2: true,
		}

		return client
	}

	tlsConf.NextProtos = []string{"h3"}
	client.Transport = &http3.Transport{
		Dial: func(
			ctx context.Context,
			_ string,
			tlsCfg *tls.Config,
			cfg *quic.Config,
		) (quic.EarlyConnection, error) {
			return quic.DialAddrEarly(ctx, serverAddr(), tlsCfg, cfg)
		},
		TLSClientConfig:    tlsConf,
		QUICConfig:         &quic.Config{},
		DisableCompression: true,
	}

	return client
}
07070100000074000081A4000000000000000000000001679A649F00003D1D000000000000000000000000000000000000002500000000dnsproxy-0.75.0/proxy/server_quic.gopackage proxy

import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"log/slog"
	"math"
	"net"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/syncutil"
	"github.com/bluele/gcache"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
)

// NextProtoDQ is the ALPN token for DoQ.  During connection establishment,
// DNS/QUIC support is indicated by selecting the ALPN token "doq" in the
// crypto handshake.
//
// DoQ RFC: https://www.rfc-editor.org/rfc/rfc9250.html
const NextProtoDQ = "doq"

// compatProtoDQ is the list of ALPN tokens offered by the DoQ QUIC listeners
// (see createQUICListeners).  It starts with NextProtoDQ, the token from the
// final RFC, and also includes the tokens of earlier drafts so that older
// clients can still connect.
var compatProtoDQ = []string{NextProtoDQ, "doq-i02", "doq-i00", "dq"}

// maxQUICIdleTimeout is the maximum QUIC idle timeout used by the server-side
// QUIC configuration (see newServerQUICConfig).  The default value in quic-go
// is 30 seconds, but our internal tests show that a higher value works better
// for clients written with ngtcp2.
const maxQUICIdleTimeout = 5 * time.Minute

// quicAddrValidatorCacheSize is the size of the LRU cache that we use in the
// QUIC address validator.  The value is chosen arbitrarily.
//
// TODO(ameshkov): make it configurable.
const quicAddrValidatorCacheSize = 1000

// quicAddrValidatorCacheTTL is the time-to-live for cache items in the QUIC
// address validator, i.e. how long a client address is exempt from address
// validation after it was first seen.  The value is chosen arbitrarily.
//
// TODO(ameshkov): make it configurable.
const quicAddrValidatorCacheTTL = 30 * time.Minute

// DoQ application error codes, as defined in RFC 9250, Section 4.3.  They are
// used when closing QUIC connections.
const (
	// DoQCodeNoError is used when the connection or stream needs to be closed,
	// but there is no error to signal.
	DoQCodeNoError quic.ApplicationErrorCode = 0
	// DoQCodeInternalError signals that the DoQ implementation encountered
	// an internal error and is incapable of pursuing the transaction or the
	// connection.
	DoQCodeInternalError quic.ApplicationErrorCode = 1
	// DoQCodeProtocolError signals that the DoQ implementation encountered
	// a protocol error and is forcibly aborting the connection.
	DoQCodeProtocolError quic.ApplicationErrorCode = 2
)

// createQUICListeners opens a UDP socket and a QUIC early listener for every
// address in p.QUICListenAddr.  Each listener uses a clone of p.TLSConfig with
// the DoQ ALPN tokens, the default server QUIC configuration, and its own
// source-address validator.  The UDP sockets, transports and listeners are
// appended to p.quicConns, p.quicTransports and p.quicListen respectively.  It
// returns the first error encountered; the UDP socket created for the failing
// address is already recorded in p.quicConns at that point.
func (p *Proxy) createQUICListeners() error {
	for _, addr := range p.QUICListenAddr {
		p.logger.Info("creating quic listener", "addr", addr)

		udpConn, listenErr := net.ListenUDP(bootstrap.NetworkUDP, addr)
		if listenErr != nil {
			return fmt.Errorf("listening to %s: %w", addr, listenErr)
		}

		p.quicConns = append(p.quicConns, udpConn)

		validator := newQUICAddrValidator(quicAddrValidatorCacheSize, quicAddrValidatorCacheTTL)
		tr := &quic.Transport{
			Conn:                udpConn,
			VerifySourceAddress: validator.requiresValidation,
		}

		serverTLS := p.TLSConfig.Clone()
		serverTLS.NextProtos = compatProtoDQ

		earlyListener, earlyErr := tr.ListenEarly(serverTLS, newServerQUICConfig())
		if earlyErr != nil {
			return fmt.Errorf("quic listener: %w", earlyErr)
		}

		p.quicTransports = append(p.quicTransports, tr)
		p.quicListen = append(p.quicListen, earlyListener)

		p.logger.Info("listening quic", "addr", earlyListener.Addr())
	}

	return nil
}

// quicPacketLoop accepts QUIC connections from l until Accept fails.  Each
// accepted connection is handled by handleQUICConnection in its own goroutine
// while holding one unit of reqSema.  If reqSema cannot be acquired, the
// just-accepted connection is closed and the loop stops.
//
// See also the comment on Proxy.requestsSema.
func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, reqSema syncutil.Semaphore) {
	p.logger.Info("entering dns-over-quic listener loop", "addr", l.Addr())

	for {
		ctx := context.Background()
		conn, err := l.Accept(ctx)
		if err != nil {
			logQUICError(ctx, "accepting quic conn", err, p.logger)

			break
		}

		err = reqSema.Acquire(ctx)
		if err != nil {
			p.logger.ErrorContext(
				ctx,
				"acquiring semaphore",
				"proto", ProtoQUIC,
				slogutil.KeyError, err,
			)

			// The connection has been accepted but will never be handled, so
			// close it here to free its resources instead of leaking it.
			closeQUICConn(conn, DoQCodeNoError, p.logger)

			break
		}

		go func() {
			defer reqSema.Release()

			p.handleQUICConnection(conn, reqSema)
		}()
	}
}

// logQUICError logs err with prefix.  Errors that isQUICErrorForDebugLog
// considers non-critical are logged at debug level as "closed or timed out"
// with prefix as an attribute; all others are logged at error level with
// prefix as the message.
func logQUICError(ctx context.Context, prefix string, err error, l *slog.Logger) {
	if !isQUICErrorForDebugLog(err) {
		l.ErrorContext(ctx, prefix, slogutil.KeyError, err)

		return
	}

	l.DebugContext(
		ctx,
		"closed or timed out",
		slogutil.KeyPrefix, prefix,
		slogutil.KeyError, err,
	)
}

// handleQUICConnection accepts streams on conn until AcceptStream fails, and
// then closes conn with DoQCodeNoError.  Each accepted stream is served in its
// own goroutine by handleQUICStream while holding one unit of reqSema, and the
// stream is closed once handling is done.  If reqSema cannot be acquired, conn
// is closed and the function returns.
//
// See also the comment on Proxy.requestsSema.
func (p *Proxy) handleQUICConnection(conn quic.Connection, reqSema syncutil.Semaphore) {
	for {
		ctx := context.Background()

		// The stub to resolver DNS traffic follows a simple pattern in which
		// the client sends a query, and the server provides a response.  For
		// each subsequent query on a QUIC connection the client MUST select
		// the next available client-initiated bidirectional stream.
		stream, acceptErr := conn.AcceptStream(ctx)
		if acceptErr != nil {
			logQUICError(ctx, "accepting quic stream", acceptErr, p.logger)

			// Close the connection to make sure resources are freed.
			closeQUICConn(conn, DoQCodeNoError, p.logger)

			return
		}

		if semErr := reqSema.Acquire(ctx); semErr != nil {
			p.logger.ErrorContext(ctx, "acquiring semaphore", slogutil.KeyError, semErr)

			// Close the connection to make sure resources are freed.
			closeQUICConn(conn, DoQCodeNoError, p.logger)

			return
		}

		go func(s quic.Stream) {
			defer reqSema.Release()

			p.handleQUICStream(ctx, s, conn)

			// The server MUST send the response(s) on the same stream and MUST
			// indicate, after the last response, through the STREAM FIN
			// mechanism that no further data will be sent on that stream.
			_ = s.Close()
		}(stream)
	}
}

// handleQUICStream reads a single DNS query from stream, decodes it (RFC 9250
// length-prefixed or old-draft unprefixed), and passes it to handleDNSRequest.
// conn is the connection stream belongs to.  conn is closed with
// DoQCodeProtocolError if the query cannot be unpacked or fails validQUICMsg.
func (p *Proxy) handleQUICStream(ctx context.Context, stream quic.Stream, conn quic.Connection) {
	bufPtr := p.bytesPool.Get().(*[]byte)
	defer p.bytesPool.Put(bufPtr)

	// One query - one stream.
	// The client MUST select the next available client-initiated bidirectional
	// stream for each subsequent query on a QUIC connection, so the whole
	// stream is read as a single message.
	//
	// readAll turns the client's STREAM FIN (io.EOF) into a nil error, so
	// whether the message is complete is judged by the number of bytes read.
	buf := *bufPtr
	n, err := readAll(stream, buf)

	// Note that io.EOF does not really mean that there's any error, this is
	// just a signal that there will be no data to read anymore from this
	// stream.
	if (err != nil && err != io.EOF) || n < minDNSPacketSize {
		logShortQUICRead(ctx, err, p.logger)

		return
	}

	// Only the first n bytes belong to this query.  The rest of the pooled
	// buffer may still contain bytes from an earlier request, which must
	// never be parsed as part of this one (e.g. to "complete" a truncated
	// query).
	msg := buf[:n]

	// In theory, we should use ALPN to get the DoQ version properly. However,
	// since there are not too many versions now, we only check how the DNS
	// query is encoded. If it's sent with a 2-byte prefix, we consider this a
	// DoQ v1. Otherwise, a draft version.
	doqVersion := DoQv1
	req := &dns.Msg{}

	// Note that we support both the old drafts and the new RFC. In the old
	// draft DNS messages were not prefixed with the message length.
	//
	// Compare as int: converting n-2 to uint16 could wrap around and produce
	// a false match for reads longer than math.MaxUint16 bytes.
	packetLen := binary.BigEndian.Uint16(msg[:2])
	if int(packetLen) == n-2 {
		err = req.Unpack(msg[2:])
	} else {
		err = req.Unpack(msg)
		doqVersion = DoQv1Draft
	}

	if err != nil {
		p.logger.ErrorContext(ctx, "unpacking quic packet", slogutil.KeyError, err)
		closeQUICConn(conn, DoQCodeProtocolError, p.logger)

		return
	}

	if !validQUICMsg(req, p.logger) {
		// If a peer encounters such an error condition, it is considered a
		// fatal error. It SHOULD forcibly abort the connection using QUIC's
		// CONNECTION_CLOSE mechanism and SHOULD use the DoQ error code
		// DOQ_PROTOCOL_ERROR.
		closeQUICConn(conn, DoQCodeProtocolError, p.logger)

		return
	}

	d := p.newDNSContext(ProtoQUIC, req, netutil.NetAddrToAddrPort(conn.RemoteAddr()))
	d.QUICStream = stream
	d.QUICConnection = conn
	d.DoQVersion = doqVersion

	err = p.handleDNSRequest(d)
	if err != nil {
		p.logger.DebugContext(
			ctx,
			"error handling dns request",
			"proto", d.Proto,
			slogutil.KeyError, err,
		)
	}
}

// respondQUIC packs d.Res and writes it to d.QUICStream.  For DoQv1 the
// message is written with a 2-byte length prefix, and for DoQv1Draft without
// one.  Any other DoQ version is an error.  If d.Res is nil, the QUIC
// connection is closed with DoQCodeInternalError and an error is returned.
func (p *Proxy) respondQUIC(d *DNSContext) error {
	resp := d.Res
	if resp == nil {
		// If no response has been written, close the QUIC connection now.
		closeQUICConn(d.QUICConnection, DoQCodeInternalError, p.logger)

		return errors.Error("no response to write")
	}

	packed, err := resp.Pack()
	if err != nil {
		return fmt.Errorf("couldn't convert message into wire format: %w", err)
	}

	// Depending on the DoQ version, we either write a 2-byte prefixed message
	// or just the message itself (for old draft versions).
	var out []byte
	switch ver := d.DoQVersion; ver {
	case DoQv1:
		out = proxyutil.AddPrefix(packed)
	case DoQv1Draft:
		out = packed
	default:
		return fmt.Errorf("invalid protocol version: %d", ver)
	}

	written, err := d.QUICStream.Write(out)
	switch {
	case err != nil:
		return fmt.Errorf("conn.Write(): %w", err)
	case written != len(out):
		return fmt.Errorf("conn.Write() returned with %d != %d", written, len(out))
	default:
		return nil
	}
}

// validQUICMsg reports whether req passes the DoQ protocol-error checks that
// can be performed on an already-unpacked message.  Currently it only rejects
// messages carrying the edns-tcp-keepalive EDNS(0) option, logging the reason
// at debug level with l.
func validQUICMsg(req *dns.Msg, l *slog.Logger) (ok bool) {
	// See https://www.rfc-editor.org/rfc/rfc9250.html#name-protocol-errors
	//
	// 1. a client or server receives a message with a non-zero Message ID.
	//
	// We do consciously not validate this case since there are stub proxies
	// that are sending a non-zero Message IDs.
	//
	// 2. a client or server receives a STREAM FIN before receiving all the
	// bytes for a message indicated in the 2-octet length field.
	// 3. a server receives more than one query on a stream
	//
	// These cases are covered earlier when unpacking the DNS message.
	//
	// 4. the client or server does not indicate the expected STREAM FIN after
	// sending requests or responses (see Section 4.2).
	//
	// Validating this would imply waiting for STREAM FIN before processing
	// the message, so this implementation consciously ignores this case.
	//
	// 6. a client or a server attempts to open a unidirectional QUIC stream.
	//
	// This case can only be handled when writing a response.
	//
	// 7. a server receives a "replayable" transaction in 0-RTT data
	//
	// The information necessary to validate this is not exposed by quic-go.

	// 5. an implementation receives a message containing the edns-tcp-keepalive
	// EDNS(0) Option [RFC7828] (see Section 5.5.2).
	opt := req.IsEdns0()
	if opt == nil {
		return true
	}

	for _, o := range opt.Option {
		if o.Option() != dns.EDNS0TCPKEEPALIVE {
			continue
		}

		l.Debug("client sent edns0 tcp keepalive option")

		return false
	}

	return true
}

// logShortQUICRead logs a failed or too-short read from a QUIC stream.  A
// non-nil err is logged through logQUICError.  A nil err means the read
// succeeded but returned too few bytes, which is logged at info level.
func logShortQUICRead(ctx context.Context, err error, l *slog.Logger) {
	if err != nil {
		logQUICError(ctx, "reading from quic stream", err, l)

		return
	}

	l.InfoContext(ctx, "quic packet too short for dns query")
}

// QUIC application error codes that isQUICErrorForDebugLog treats as
// non-critical, so that errors carrying them are only logged at debug level.
const (
	// qCodeNoError is returned when the QUIC connection was gracefully closed
	// and there is no error to signal.
	qCodeNoError = quic.ApplicationErrorCode(quic.NoError)

	// qCodeApplicationErrorError is used for Initial and Handshake packets.
	// This error is considered as non-critical and will not be logged as error.
	qCodeApplicationErrorError = quic.ApplicationErrorCode(quic.ApplicationErrorErrorCode)
)

// isQUICErrorForDebugLog reports whether err is a non-critical error, most
// probably related to the current QUIC implementation, that should only be
// logged at debug level.  err must not be nil.
//
// TODO(ameshkov): re-test when updating quic-go.
func isQUICErrorForDebugLog(err error) (ok bool) {
	// quic.ErrServerClosed is returned when we closed the QUIC listener
	// ourselves, and quic.Err0RTTRejected is returned from AcceptStream when
	// the server rejects 0-RTT.  Both are expected and common.
	if errors.Is(err, quic.ErrServerClosed) || errors.Is(err, quic.Err0RTTRejected) {
		return true
	}

	// Graceful closes and Initial/Handshake application errors need no
	// detailed logs either.
	//
	// TODO(a.garipov): Consider adding other error codes.
	var appErr *quic.ApplicationError
	if errors.As(err, &appErr) {
		code := appErr.ErrorCode
		if code == qCodeNoError || code == qCodeApplicationErrorError {
			return true
		}
	}

	// An idle timeout occurs when accepting a new stream from a connection
	// that had no activity for longer than the keep-alive timeout.  This is a
	// common scenario, no need for extra logs.
	var idleErr *quic.IdleTimeoutError

	return errors.As(err, &idleErr)
}

// closeQUICConn closes conn with the application error code, logging both the
// attempt and any close error only at debug level, so failures are never
// surfaced to the caller.
func closeQUICConn(conn quic.Connection, code quic.ApplicationErrorCode, l *slog.Logger) {
	l.Debug("closing quic conn", "addr", conn.LocalAddr(), "code", code)

	if closeErr := conn.CloseWithError(code, ""); closeErr != nil {
		l.Debug("closing quic connection", "code", code, slogutil.KeyError, closeErr)
	}
}

// newServerQUICConfig returns a new *quic.Config with the server-side defaults.
// The config uses maxQUICIdleTimeout as the idle timeout, allows up to
// math.MaxUint16 incoming bidirectional and unidirectional streams, and
// enables 0-RTT.  It is meant for both the DoQ and the DoH3 servers.
func newServerQUICConfig() (conf *quic.Config) {
	conf = &quic.Config{}

	conf.MaxIdleTimeout = maxQUICIdleTimeout
	conf.MaxIncomingStreams = math.MaxUint16
	conf.MaxIncomingUniStreams = math.MaxUint16

	// Enable 0-RTT by default for all connections on the server-side.
	conf.Allow0RTT = true

	return conf
}

// quicAddrValidator is a helper struct that holds a small LRU cache of
// addresses for which we do not require address validation.  Its
// requiresValidation method is used as the VerifySourceAddress callback of the
// QUIC transport.
type quicAddrValidator struct {
	// cache maps the string form of a client IP to a marker value; see
	// newQUICAddrValidator for how it is built.
	cache gcache.Cache
	// ttl is how long an IP stays in cache after it is first seen.
	ttl   time.Duration
}

// newQUICAddrValidator returns a *quicAddrValidator backed by an LRU cache
// holding at most cacheSize addresses, each kept for ttl.
func newQUICAddrValidator(cacheSize int, ttl time.Duration) (v *quicAddrValidator) {
	lru := gcache.New(cacheSize).LRU().Build()

	return &quicAddrValidator{cache: lru, ttl: ttl}
}

// requiresValidation reports whether the server should send a QUIC Retry
// packet to the client at addr.  A Retry lets the server verify the client's
// address but adds latency, so it is only requested the first time an IP is
// seen within v.ttl.  addr must be a *net.UDPAddr; otherwise it panics.
func (v *quicAddrValidator) requiresValidation(addr net.Addr) (ok bool) {
	udpAddr := addr.(*net.UDPAddr)
	ip := udpAddr.IP.String()

	if known := v.cache.Has(ip); known {
		return false
	}

	// Remember the IP so that its next connections within ttl skip the Retry.
	if err := v.cache.SetWithExpire(ip, true, v.ttl); err != nil {
		// Shouldn't happen, since we don't set a serialization function.
		panic(fmt.Errorf("quic validator: setting cache item: %w", err))
	}

	return true
}

// readAll reads from r into buf until r returns an error or io.EOF, or until
// buf is full.  It returns the number of bytes read.  io.EOF is reported as a
// nil error.  If buf becomes full before r is exhausted (including when buf is
// empty), it returns io.ErrShortBuffer, without checking whether r had more
// data.  Unlike io.ReadAll it never allocates, and unlike io.ReadFull it stops
// at io.EOF instead of requiring len(buf) bytes.
func readAll(r io.Reader, buf []byte) (n int, err error) {
	for n < len(buf) {
		var got int
		got, err = r.Read(buf[n:])
		n += got

		switch {
		case err == io.EOF:
			return n, nil
		case err != nil:
			return n, err
		}
	}

	return n, io.ErrShortBuffer
}
07070100000075000081A4000000000000000000000001679A649F00001AB3000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/proxy/server_quic_test.gopackage proxy

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"io"
	"net"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/stretchr/testify/require"
)

func TestQuicProxy(t *testing.T) {
	serverConfig, caPem := newTLSConfig(t)

	roots := x509.NewCertPool()
	roots.AppendCertsFromPEM(caPem)
	tlsConfig := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    roots,
		NextProtos: append([]string{NextProtoDQ}, compatProtoDQ...),
	}

	conf := &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TLSConfig:              serverConfig,
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	}

	// startProxy creates and starts a proxy from the current conf and shuts it
	// down when t finishes.
	startProxy := func(t *testing.T) (dnsProxy *Proxy) {
		t.Helper()

		dnsProxy = mustNew(t, conf)

		ctx := context.Background()
		require.NoError(t, dnsProxy.Start(ctx))
		testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

		return dnsProxy
	}

	// dialDoQ opens a DoQ client connection to target and closes it when t
	// finishes.
	dialDoQ := func(t *testing.T, target string) (conn quic.Connection) {
		t.Helper()

		conn, err := quic.DialAddrEarly(context.Background(), target, tlsConfig, nil)
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, func() (err error) {
			return conn.CloseWithError(DoQCodeNoError, "")
		})

		return conn
	}

	var addr *net.UDPAddr
	t.Run("run", func(t *testing.T) {
		dnsProxy := startProxy(t)
		addr = testutil.RequireTypeAssert[*net.UDPAddr](t, dnsProxy.Addr(ProtoQUIC))

		conn := dialDoQ(t, addr.String())
		for range 10 {
			sendTestQUICMessage(t, conn, DoQv1)

			// Send a message encoded for a draft version as well.
			sendTestQUICMessage(t, conn, DoQv1Draft)
		}
	})
	require.False(t, t.Failed())

	// Restart on the very same address to make sure it can be reused.
	conf.QUICListenAddr = []*net.UDPAddr{addr}
	conf.UpstreamConfig = newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr)

	t.Run("rerun", func(t *testing.T) {
		_ = startProxy(t)

		conn := dialDoQ(t, addr.String())
		sendTestQUICMessage(t, conn, DoQv1)

		// Send a message encoded for a draft version as well.
		sendTestQUICMessage(t, conn, DoQv1Draft)
	})
}

func TestQuicProxy_largePackets(t *testing.T) {
	serverConfig, caPem := newTLSConfig(t)

	// answerLocally replies with a single A record so that the request does
	// not go to any real upstream.
	answerLocally := func(_ *Proxy, d *DNSContext) (err error) {
		resp := &dns.Msg{}
		resp.SetReply(d.Req)
		resp.Answer = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   d.Req.Question[0].Name,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
			},
			A: net.IP{8, 8, 8, 8},
		}}
		d.Res = resp

		return nil
	}

	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		TLSListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		HTTPSListenAddr:        []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TLSConfig:              serverConfig,
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
		RequestHandler:         answerLocally,
	})

	ctx := context.Background()
	require.NoError(t, dnsProxy.Start(ctx))
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	pool := x509.NewCertPool()
	pool.AppendCertsFromPEM(caPem)
	clientTLS := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    pool,
		NextProtos: append([]string{NextProtoDQ}, compatProtoDQ...),
	}

	// Open a DNS-over-QUIC client connection.
	conn, err := quic.DialAddrEarly(context.Background(), dnsProxy.Addr(ProtoQUIC).String(), clientTLS, nil)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) {
		return conn.CloseWithError(DoQCodeNoError, "")
	})

	// Pad the query so that it spans multiple QUIC frames.
	req := newTestMessage()
	req.Extra = []dns.RR{
		&dns.OPT{
			Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT, Class: 4096},
			Option: []dns.EDNS0{
				&dns.EDNS0_PADDING{Padding: make([]byte, 4096)},
			},
		},
	}

	requireResponse(t, req, sendQUICMessage(t, req, conn, DoQv1))
}

// sendQUICMessage sends msg on a new stream of conn, encoded for doqVersion,
// and returns the unpacked response.  For DoQv1 both the query and the
// response carry a 2-byte length prefix.  For draft versions they do not.  The
// test fails if any step fails or the response is too short.
func sendQUICMessage(
	t *testing.T,
	msg *dns.Msg,
	conn quic.Connection,
	doqVersion DoQVersion,
) (resp *dns.Msg) {
	// Open a new QUIC stream to write there a test DNS query.
	stream, err := conn.OpenStreamSync(context.Background())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, stream.Close)

	packedMsg, err := msg.Pack()
	require.NoError(t, err)

	buf := packedMsg
	if doqVersion == DoQv1 {
		buf = proxyutil.AddPrefix(packedMsg)
	}

	// Send the DNS query to the stream.
	err = writeQUICStream(buf, stream)
	require.NoError(t, err)

	// Close closes the write-direction of the stream and sends
	// a STREAM FIN packet.
	_ = stream.Close()

	// Read the response until the server's STREAM FIN.  A single Read may
	// return only part of the data when the response spans several frames.
	respBytes, err := io.ReadAll(stream)
	require.NoError(t, err)
	require.Greater(t, len(respBytes), minDNSPacketSize)

	if doqVersion == DoQv1 {
		respBytes = respBytes[2:]
	}

	// Unpack exactly the bytes received.
	resp = new(dns.Msg)
	err = resp.Unpack(respBytes)
	require.NoError(t, err)

	return resp
}

// writeQUICStream writes buf to stream in chunks of at most 400 bytes (an
// arbitrarily chosen size), so that tests exercise how the server handles DNS
// messages split across several writes.  When buf needs more than one chunk,
// it sleeps for a millisecond after every chunk to emulate network latency.
// It returns the first write error.
func writeQUICStream(buf []byte, stream quic.Stream) (err error) {
	const chunkSize = 400

	emulateLatency := len(buf) > chunkSize
	for rest := buf; len(rest) > 0; {
		end := chunkSize
		if end > len(rest) {
			end = len(rest)
		}

		if _, err = stream.Write(rest[:end]); err != nil {
			return err
		}

		rest = rest[end:]

		if emulateLatency {
			time.Sleep(time.Millisecond)
		}
	}

	return nil
}

// sendTestQUICMessage sends a fresh test query over conn, encoded for
// doqVersion, and requires a matching response.
func sendTestQUICMessage(t *testing.T, conn quic.Connection, doqVersion DoQVersion) {
	req := newTestMessage()
	requireResponse(t, req, sendQUICMessage(t, req, conn, doqVersion))
}
07070100000076000081A4000000000000000000000001679A649F000014B7000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/server_tcp.gopackage proxy

import (
	"context"
	"crypto/tls"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/syncutil"
	"github.com/miekg/dns"
)

// createTCPListeners opens a TCP listener for every address in p.TCPListenAddr
// using the listen configuration from proxynetutil.ListenConfig, and appends
// them to p.tcpListen.  It returns an error if listening fails or the listener
// is not a *net.TCPListener.
func (p *Proxy) createTCPListeners(ctx context.Context) (err error) {
	for _, addr := range p.TCPListenAddr {
		p.logger.Info("creating tcp server socket", "addr", addr)

		l, listenErr := proxynetutil.ListenConfig(p.logger).Listen(ctx, bootstrap.NetworkTCP, addr.String())
		if listenErr != nil {
			return fmt.Errorf("listening to tcp socket: %w", listenErr)
		}

		tcpL, isTCP := l.(*net.TCPListener)
		if !isTCP {
			return fmt.Errorf("wrong listener type on tcp addr %s: %T", addr, l)
		}

		p.tcpListen = append(p.tcpListen, tcpL)

		p.logger.Info("listening to tcp", "addr", tcpL.Addr())
	}

	return nil
}

// createTLSListeners opens a TCP listener for every address in p.TLSListenAddr,
// wraps it in a TLS listener using p.TLSConfig, and appends the result to
// p.tlsListen.  It returns the first listening error.
func (p *Proxy) createTLSListeners() (err error) {
	for _, addr := range p.TLSListenAddr {
		p.logger.Info("creating tls server socket", "addr", addr)

		var tcpL *net.TCPListener
		if tcpL, err = net.ListenTCP("tcp", addr); err != nil {
			return fmt.Errorf("listening on tls addr %s: %w", addr, err)
		}

		tlsL := tls.NewListener(tcpL, p.TLSConfig)
		p.tlsListen = append(p.tlsListen, tlsL)

		p.logger.Info("listening to tls", "addr", tlsL.Addr())
	}

	return nil
}

// tcpPacketLoop accepts connections from l until Accept fails and serves each
// one with handleTCPConnection in its own goroutine.  proto must be either
// [ProtoTCP] or [ProtoTLS].  One unit of reqSema is acquired per accepted
// connection and released by handleTCPConnection.  If it cannot be acquired,
// the just-accepted connection is closed and the loop stops.
//
// See also the comment on Proxy.requestsSema.
func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, reqSema syncutil.Semaphore) {
	p.logger.Info("entering listener loop", "proto", proto, "addr", l.Addr())

	for {
		clientConn, err := l.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				p.logger.Debug("tcp connection closed", "addr", l.Addr())
			} else {
				p.logger.Error("reading from tcp", slogutil.KeyError, err)
			}

			break
		}

		// TODO(d.kolyshev): Pass and use context from above.
		err = reqSema.Acquire(context.Background())
		if err != nil {
			p.logger.Error("acquiring semaphore", "proto", proto, slogutil.KeyError, err)

			// The connection has been accepted but will never be handled, so
			// close it here instead of leaking the socket.
			closeErr := clientConn.Close()
			if closeErr != nil {
				logWithNonCrit(closeErr, "closing conn", proto, p.logger)
			}

			break
		}

		go p.handleTCPConnection(clientConn, proto, reqSema)
	}
}

// handleTCPConnection serves DNS queries read from conn for as long as the
// proxy is started and reading succeeds.  proto must be either [ProtoTCP] or
// [ProtoTLS].  It is used both for the DNS context and in log messages.  On
// return it closes conn and releases one unit of reqSema, which the caller
// acquired.  slogutil.RecoverAndLog is deferred to handle panics.
func (p *Proxy) handleTCPConnection(conn net.Conn, proto Proto, reqSema syncutil.Semaphore) {
	defer slogutil.RecoverAndLog(context.TODO(), p.logger)
	defer reqSema.Release()
	defer func() {
		err := conn.Close()
		if err != nil {
			logWithNonCrit(err, "closing conn", proto, p.logger)
		}
	}()

	p.logger.Debug("handling new request", "proto", proto, "raddr", conn.RemoteAddr())

	for p.isStarted() {
		// Refresh the deadline before every query so that an idle client
		// cannot keep the connection open indefinitely.
		err := conn.SetDeadline(time.Now().Add(defaultTimeout))
		if err != nil {
			// Consider deadline errors non-critical.
			logWithNonCrit(err, "setting deadline", proto, p.logger)
		}

		req := p.readDNSReq(conn)
		if req == nil {
			return
		}

		d := p.newDNSContext(proto, req, netutil.NetAddrToAddrPort(conn.RemoteAddr()))
		d.Conn = conn

		err = p.handleDNSRequest(d)
		if err != nil {
			logWithNonCrit(err, "handling request", proto, p.logger)
		}
	}
}

// readDNSReq reads one length-prefixed DNS message from conn and unpacks it.
// It returns nil if reading or unpacking fails, after logging the error
// (reading errors through logWithNonCrit).
func (p *Proxy) readDNSReq(conn net.Conn) (req *dns.Msg) {
	packet, err := readPrefixed(conn)
	if err != nil {
		logWithNonCrit(err, "reading msg", ProtoTCP, p.logger)

		return nil
	}

	msg := new(dns.Msg)
	if err = msg.Unpack(packet); err != nil {
		p.logger.Error("handling tcp; unpacking msg", slogutil.KeyError, err)

		return nil
	}

	return msg
}

// errTooLarge means that a DNS message is larger than 64KiB, i.e. its length
// exceeds dns.MaxMsgSize and cannot be described by a 2-byte length prefix.
const errTooLarge errors.Error = "dns message is too large"

// readPrefixed reads a DNS message from conn.  The message is preceded by a
// 2-byte big-endian prefix containing its length.  It returns errTooLarge if
// the prefix exceeds dns.MaxMsgSize.
//
// Both the prefix and the message are read with io.ReadFull, because a single
// Read on a stream connection may return fewer bytes than requested, e.g. only
// the first byte of the prefix.
func readPrefixed(conn net.Conn) (b []byte, err error) {
	l := make([]byte, 2)
	_, err = io.ReadFull(conn, l)
	if err != nil {
		return nil, fmt.Errorf("reading len: %w", err)
	}

	packetLen := binary.BigEndian.Uint16(l)
	if packetLen > dns.MaxMsgSize {
		return nil, errTooLarge
	}

	b = make([]byte, packetLen)
	_, err = io.ReadFull(conn, b)
	if err != nil {
		return nil, fmt.Errorf("reading msg: %w", err)
	}

	return b, nil
}

// respondTCP writes the response stored in d to the TCP or TLS client.  If d
// holds no response, it closes the client connection instead.  A write failure
// caused by an already closed connection is not reported.
func (p *Proxy) respondTCP(d *DNSContext) error {
	conn := d.Conn

	// There is nothing to send, so drop the connection right away.
	if d.Res == nil {
		return conn.Close()
	}

	packed, err := d.Res.Pack()
	if err != nil {
		return fmt.Errorf("packing message: %w", err)
	}

	if err = writePrefixed(packed, conn); err == nil || errors.Is(err, net.ErrClosed) {
		return nil
	}

	return fmt.Errorf("writing message: %w", err)
}

// writePrefixed writes the DNS message b to conn.  It first writes b's length
// as a 2-byte big-endian prefix, followed by the message itself.  It returns
// errTooLarge, without writing anything, if b is too long for its length to
// fit into the prefix.
func writePrefixed(b []byte, conn net.Conn) (err error) {
	// A longer message would have its length silently truncated by the uint16
	// conversion below, which would desynchronize the stream for the client.
	if len(b) > dns.MaxMsgSize {
		return errTooLarge
	}

	l := make([]byte, 2)
	binary.BigEndian.PutUint16(l, uint16(len(b)))
	_, err = (&net.Buffers{l, b}).WriteTo(conn)

	return err
}
07070100000077000081A4000000000000000000000001679A649F00000692000000000000000000000000000000000000002900000000dnsproxy-0.75.0/proxy/server_tcp_test.gopackage proxy

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"net"
	"testing"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/require"
)

// TestTcpProxy checks that the default proxy answers DNS-over-TCP queries.
func TestTcpProxy(t *testing.T) {
	dnsProxy := mustStartDefaultProxy(t)

	// Create a DNS-over-TCP client connection.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	require.NoError(t, err)
	t.Cleanup(func() {
		// Best-effort: the connection is closed only to avoid leaking it past
		// the test, and a close failure doesn't affect the test's result.
		_ = conn.Close()
	})

	sendTestMessages(t, conn)
}

// TestTlsProxy checks that a proxy configured with a TLS certificate answers
// DNS-over-TLS queries.
func TestTlsProxy(t *testing.T) {
	serverConfig, caPem := newTLSConfig(t)
	dnsProxy := mustNew(t, &Config{
		Logger:                 slogutil.NewDiscardLogger(),
		TLSListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		HTTPSListenAddr:        []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		QUICListenAddr:         []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TLSConfig:              serverConfig,
		UpstreamConfig:         newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr),
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	})

	// Start listening.
	ctx := context.Background()
	err := dnsProxy.Start(ctx)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) })

	// Fail early on a broken CA certificate rather than with an obscure
	// handshake error when dialing below.
	roots := x509.NewCertPool()
	require.True(t, roots.AppendCertsFromPEM(caPem), "appending ca certificate")
	tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots}

	// Create a DNS-over-TLS client connection.
	addr := dnsProxy.Addr(ProtoTLS)
	conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
	require.NoError(t, err)
	t.Cleanup(func() {
		// Best-effort: the connection is closed only to avoid leaking it past
		// the test, and a close failure doesn't affect the test's result.
		_ = conn.Close()
	})

	sendTestMessages(t, conn)
}
07070100000078000081A4000000000000000000000001679A649F00001102000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/server_udp.gopackage proxy

import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/netip"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/syncutil"
	"github.com/miekg/dns"
)

func (p *Proxy) createUDPListeners(ctx context.Context) (err error) {
	for _, a := range p.UDPListenAddr {
		var pc *net.UDPConn
		pc, sErr := p.udpCreate(ctx, a)
		if sErr != nil {
			return fmt.Errorf("listening on udp addr %s: %w", a, sErr)
		}

		p.udpListen = append(p.udpListen, pc)
	}

	return nil
}

// udpCreate creates a UDP listening socket bound to udpAddr.  It sets the read
// buffer size to p.Config.UDPBufferSize when that is positive, then applies
// proxynetutil.UDPSetOptions.  If either step fails, the socket is closed and
// an error is returned.
func (p *Proxy) udpCreate(ctx context.Context, udpAddr *net.UDPAddr) (*net.UDPConn, error) {
	p.logger.InfoContext(ctx, "creating udp server socket", "addr", udpAddr)

	lc := proxynetutil.ListenConfig(p.logger)
	packetConn, listenErr := lc.ListenPacket(ctx, bootstrap.NetworkUDP, udpAddr.String())
	if listenErr != nil {
		return nil, fmt.Errorf("listening to udp socket: %w", listenErr)
	}

	conn := packetConn.(*net.UDPConn)

	if bufSize := p.Config.UDPBufferSize; bufSize > 0 {
		if bufErr := conn.SetReadBuffer(bufSize); bufErr != nil {
			_ = conn.Close()

			return nil, fmt.Errorf("setting udp buf size: %w", bufErr)
		}
	}

	if optErr := proxynetutil.UDPSetOptions(conn); optErr != nil {
		_ = conn.Close()

		return nil, fmt.Errorf("setting udp opts: %w", optErr)
	}

	p.logger.InfoContext(ctx, "listening to udp", "addr", conn.LocalAddr())

	return conn, nil
}

// udpPacketLoop listens for incoming UDP packets on conn while the proxy is
// started.  Every received packet is copied and handled in its own goroutine.
// That goroutine holds reqSema until udpHandlePacket returns.  The loop stops
// when reading from conn returns an error or reqSema can't be acquired.
//
// See also the comment on Proxy.requestsSema.
func (p *Proxy) udpPacketLoop(conn *net.UDPConn, reqSema syncutil.Semaphore) {
	p.logger.Info("entering udp listener loop", "addr", conn.LocalAddr())

	// b is the read buffer shared by all iterations; it's sized to
	// dns.MaxMsgSize.
	b := make([]byte, dns.MaxMsgSize)
	for p.isStarted() {
		n, localIP, remoteAddr, err := proxynetutil.UDPRead(conn, b, p.udpOOBSize)
		// The documentation says to handle the packet even if err occurs.
		if n > 0 {
			// Make a copy of all bytes because ReadFrom() will overwrite the
			// contents of b on the next call.  We need that contents to sustain
			// the call because we're handling them in goroutines.
			packet := make([]byte, n)
			copy(packet, b)

			// TODO(d.kolyshev): Pass and use context from above.
			sErr := reqSema.Acquire(context.Background())
			if sErr != nil {
				// NOTE: A read error returned together with this packet is
				// not logged on this path.
				p.logger.Error("acquiring semaphore", "proto", ProtoUDP, slogutil.KeyError, sErr)

				break
			}
			go func() {
				defer reqSema.Release()

				p.udpHandlePacket(packet, localIP, remoteAddr, conn)
			}()
		}

		if err != nil {
			logUDPConnError(err, conn, p.logger)

			break
		}
	}
}

// logUDPConnError logs err, which was returned by reading from conn, using l.
// A closed connection, presumably the result of the listener being shut down,
// is logged at debug level.  Any other error is logged at error level.
func logUDPConnError(err error, conn *net.UDPConn, l *slog.Logger) {
	if !errors.Is(err, net.ErrClosed) {
		l.Error("reading from udp", slogutil.KeyError, err)

		return
	}

	l.Debug("udp connection closed", "addr", conn.LocalAddr())
}

// udpHandlePacket unpacks a single UDP packet received from remoteAddr on conn
// and passes it to p.handleDNSRequest.  localIP is recorded in the DNS context
// along with conn.  A packet that fails to unpack is logged and dropped.
func (p *Proxy) udpHandlePacket(
	packet []byte,
	localIP netip.Addr,
	remoteAddr *net.UDPAddr,
	conn *net.UDPConn,
) {
	p.logger.Debug("handling new udp packet", "raddr", remoteAddr)

	req := new(dns.Msg)
	if err := req.Unpack(packet); err != nil {
		p.logger.Error("unpacking udp packet", slogutil.KeyError, err)

		return
	}

	dctx := p.newDNSContext(ProtoUDP, req, netutil.NetAddrToAddrPort(remoteAddr))
	dctx.Conn = conn
	dctx.localIP = localIP

	if err := p.handleDNSRequest(dctx); err != nil {
		p.logger.Debug("handling dns request", "proto", dctx.Proto, slogutil.KeyError, err)
	}
}

// respondUDP writes the response stored in d to the UDP client at d.Addr.  It
// passes d.localIP to proxynetutil.UDPWrite, presumably as the source address.
// It does nothing if d holds no response, and it ignores errors caused by an
// already closed connection.  A short write is reported as an error.
func (p *Proxy) respondUDP(d *DNSContext) error {
	if d.Res == nil {
		// Do nothing if no response has been written.
		return nil
	}

	packed, err := d.Res.Pack()
	if err != nil {
		return fmt.Errorf("packing message: %w", err)
	}

	udpConn := d.Conn.(*net.UDPConn)
	n, err := proxynetutil.UDPWrite(packed, udpConn, net.UDPAddrFromAddrPort(d.Addr), d.localIP)

	switch {
	case errors.Is(err, net.ErrClosed):
		return nil
	case err != nil:
		return fmt.Errorf("writing message: %w", err)
	case n != len(packed):
		return fmt.Errorf("udpWrite() returned with %d != %d", n, len(packed))
	default:
		return nil
	}
}
07070100000079000081A4000000000000000000000001679A649F00000160000000000000000000000000000000000000002900000000dnsproxy-0.75.0/proxy/server_udp_test.gopackage proxy

import (
	"testing"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/require"
)

// TestUdpProxy checks that the default proxy answers DNS-over-UDP queries.
func TestUdpProxy(t *testing.T) {
	dnsProxy := mustStartDefaultProxy(t)

	// Create a DNS-over-UDP client connection.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	require.NoError(t, err)
	t.Cleanup(func() {
		// Best-effort: the socket is closed only to avoid leaking it past the
		// test, and a close failure doesn't affect the test's result.
		_ = conn.Close()
	})

	sendTestMessages(t, conn)
}
0707010000007A000081A4000000000000000000000001679A649F000013F0000000000000000000000000000000000000001F00000000dnsproxy-0.75.0/proxy/stats.gopackage proxy

import (
	"fmt"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/miekg/dns"
)

// upstreamWithStats is a wrapper around the [upstream.Upstream] interface that
// gathers statistics.  It keeps only the outcome of the most recent Exchange
// call.  Every call overwrites err and queryDuration.
type upstreamWithStats struct {
	// upstream is the upstream DNS resolver.
	upstream upstream.Upstream

	// err is the DNS lookup error, if any.
	err error

	// queryDuration is the duration of the DNS lookup.  Exchange records it
	// for failed lookups as well.
	queryDuration time.Duration
}

// type check
var _ upstream.Upstream = (*upstreamWithStats)(nil)

// Exchange implements the [upstream.Upstream] interface for
// *upstreamWithStats.  It forwards req to the wrapped upstream and records both
// the resulting error and the time taken, whether the lookup succeeded or not.
func (u *upstreamWithStats) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	startedAt := time.Now()

	resp, err = u.upstream.Exchange(req)

	u.err = err
	u.queryDuration = time.Since(startedAt)

	return resp, err
}

// Address implements the [upstream.Upstream] interface for *upstreamWithStats
// by returning the address of the wrapped upstream.
func (u *upstreamWithStats) Address() (addr string) {
	addr = u.upstream.Address()

	return addr
}

// Close implements the [upstream.Upstream] interface for *upstreamWithStats by
// closing the wrapped upstream.
func (u *upstreamWithStats) Close() (err error) {
	err = u.upstream.Close()

	return err
}

// upstreamsWithStats wraps every element of upstreams into an
// [upstreamWithStats] and returns the wrapped upstreams in the same order.  The
// result is never nil, even for empty input.
func upstreamsWithStats(upstreams []upstream.Upstream) (wrapped []upstream.Upstream) {
	wrapped = make([]upstream.Upstream, len(upstreams))
	for i, u := range upstreams {
		wrapped[i] = &upstreamWithStats{upstream: u}
	}

	return wrapped
}

// QueryStatistics contains the DNS query statistics for both the upstream and
// fallback DNS servers.  main holds the statistics of the main upstreams, and
// fallback those of the fallback upstreams; either list may be empty.  Use
// [QueryStatistics.Main] and [QueryStatistics.Fallback] to read them.
type QueryStatistics struct {
	main     []*UpstreamStatistics
	fallback []*UpstreamStatistics
}

// cachedQueryStatistics returns the DNS query statistics for a query served
// from the cache.  The statistics contain a single main entry with the given
// address and IsCached set, and no fallback entries.
func cachedQueryStatistics(addr string) (s *QueryStatistics) {
	cached := &UpstreamStatistics{
		Address:  addr,
		IsCached: true,
	}

	s = &QueryStatistics{main: []*UpstreamStatistics{cached}}

	return s
}

// Main returns the DNS query statistics collected for the main upstream DNS
// servers.  The returned slice is not copied.
func (s *QueryStatistics) Main() (us []*UpstreamStatistics) {
	us = s.main

	return us
}

// Fallback returns the DNS query statistics collected for the fallback DNS
// servers.  The returned slice is not copied.
func (s *QueryStatistics) Fallback() (us []*UpstreamStatistics) {
	us = s.fallback

	return us
}

// collectQueryStats gathers the statistics from the wrapped upstreams.
// resolver is an upstream DNS resolver that successfully resolved the request,
// if any.  Provided upstreams must be of type [*upstreamWithStats].  unwrapped
// is the unwrapped resolver, see [upstreamWithStats.upstream].  The returned
// statistics depend on whether the DNS request was successfully resolved and
// the upstream mode, see [DNSContext.QueryStatistics].
func collectQueryStats(
	mode UpstreamMode,
	resolver upstream.Upstream,
	upstreams []upstream.Upstream,
	fallbacks []upstream.Upstream,
) (unwrapped upstream.Upstream, stats *QueryStatistics) {
	var wrapped *upstreamWithStats
	if resolver != nil {
		var ok bool
		wrapped, ok = resolver.(*upstreamWithStats)
		if !ok {
			// Should never happen.
			panic(fmt.Errorf("unexpected type %T", resolver))
		}

		unwrapped = wrapped.upstream
	}

	// The DNS query was not resolved.
	if wrapped == nil {
		return nil, &QueryStatistics{
			main:     collectUpstreamStats(upstreams...),
			fallback: collectUpstreamStats(fallbacks...),
		}
	}

	// The DNS query was successfully resolved by main resolver and the upstream
	// mode is [UpstreamModeFastestAddr].
	if mode == UpstreamModeFastestAddr && len(fallbacks) == 0 {
		return unwrapped, &QueryStatistics{
			main: collectUpstreamStats(upstreams...),
		}
	}

	// The DNS query was resolved by fallback resolver.
	if len(fallbacks) > 0 {
		return unwrapped, &QueryStatistics{
			main:     collectUpstreamStats(upstreams...),
			fallback: collectUpstreamStats(wrapped),
		}
	}

	// The DNS query was successfully resolved by main resolver.
	return unwrapped, &QueryStatistics{
		main: collectUpstreamStats(wrapped),
	}
}

// UpstreamStatistics contains the DNS query statistics of a single upstream, or
// of the cache when IsCached is set.
type UpstreamStatistics struct {
	// Error is the DNS lookup error, if any.
	Error error

	// Address is the address of the upstream DNS resolver.
	//
	// TODO(s.chzhen):  Use [upstream.Upstream] when [cacheItem] starts to
	// contain one.
	Address string

	// QueryDuration is the duration of the DNS lookup.  For upstream lookups it
	// is recorded even when the lookup failed; for cached responses it is
	// zero.
	QueryDuration time.Duration

	// IsCached indicates whether the response was served from a cache.
	IsCached bool
}

// collectUpstreamStats converts the recorded state of every wrapped upstream
// into an [UpstreamStatistics], preserving the order of upstreams.  Every
// element of upstreams must be of type *upstreamWithStats; otherwise
// collectUpstreamStats panics.
func collectUpstreamStats(upstreams ...upstream.Upstream) (stats []*UpstreamStatistics) {
	stats = make([]*UpstreamStatistics, len(upstreams))

	for i, u := range upstreams {
		w, ok := u.(*upstreamWithStats)
		if !ok {
			// Should never happen.
			panic(fmt.Errorf("unexpected type %T", u))
		}

		stats[i] = &UpstreamStatistics{
			Error:         w.err,
			Address:       w.Address(),
			QueryDuration: w.queryDuration,
		}
	}

	return stats
}
0707010000007B000081A4000000000000000000000001679A649F00001E58000000000000000000000000000000000000002400000000dnsproxy-0.75.0/proxy/stats_test.gopackage proxy_test

import (
	"net"
	"net/netip"
	"testing"

	"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestCollectQueryStats resolves a request with [proxy.Proxy.Resolve] for each
// upstream mode, using combinations of succeeding and failing main and fallback
// upstreams.  It checks the number of main and fallback statistics entries and
// whether any of them carries a lookup error.
func TestCollectQueryStats(t *testing.T) {
	const (
		listenIP = "127.0.0.1"
	)

	var (
		testReq = &dns.Msg{
			Question: []dns.Question{{
				Name:   "test.",
				Qtype:  dns.TypeA,
				Qclass: dns.ClassINET,
			}},
		}

		defaultTrustedProxies netutil.SubnetSet = netutil.SliceSubnetSet{
			netip.MustParsePrefix("0.0.0.0/0"),
			netip.MustParsePrefix("::0/0"),
		}

		localhostAnyPort = netip.MustParseAddrPort(netutil.JoinHostPort(listenIP, 0))
	)

	// ups always succeeds with an empty reply to the request.
	ups := &dnsproxytest.FakeUpstream{
		OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			return (&dns.Msg{}).SetReply(req), nil
		},
		OnAddress: func() (addr string) { return "upstream" },
		OnClose:   func() (err error) { return nil },
	}

	// failUps always fails the exchange.
	failUps := &dnsproxytest.FakeUpstream{
		OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
			return nil, errors.Error("exchange error")
		},
		OnAddress: func() (addr string) { return "fail.upstream" },
		OnClose:   func() (err error) { return nil },
	}

	// conf is shared by all subtests, which overwrite its upstream settings,
	// so the subtests must not run in parallel.
	conf := &proxy.Config{
		Logger:                 slogutil.NewDiscardLogger(),
		UDPListenAddr:          []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
		TCPListenAddr:          []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
		TrustedProxies:         defaultTrustedProxies,
		RatelimitSubnetLenIPv4: 24,
		RatelimitSubnetLenIPv6: 64,
	}

	testCases := []struct {
		wantErr           assert.ErrorAssertionFunc
		wantMainErr       assert.BoolAssertionFunc
		wantFallbackErr   assert.BoolAssertionFunc
		config            *proxy.UpstreamConfig
		fallbackConfig    *proxy.UpstreamConfig
		name              string
		mode              proxy.UpstreamMode
		wantMainCount     int
		wantFallbackCount int
	}{{
		wantErr:         assert.NoError,
		wantMainErr:     assert.False,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "load_balance_success",
		mode:              proxy.UpstreamModeLoadBalance,
		wantMainCount:     1,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.Error,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.True,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps},
		},
		name:              "load_balance_bad",
		mode:              proxy.UpstreamModeLoadBalance,
		wantMainCount:     1,
		wantFallbackCount: 2,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.False,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "parallel_success",
		mode:              proxy.UpstreamModeParallel,
		wantMainCount:     1,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "parallel_bad_fallback_success",
		mode:              proxy.UpstreamModeParallel,
		wantMainCount:     1,
		wantFallbackCount: 1,
	}, {
		wantErr:         assert.Error,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.True,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps, failUps},
		},
		name:              "parallel_bad",
		mode:              proxy.UpstreamModeParallel,
		wantMainCount:     2,
		wantFallbackCount: 3,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.False,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "fastest_single_success",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     1,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.False,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups, ups},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "fastest_multiple_success",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     2,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "fastest_mixed_success",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     2,
		wantFallbackCount: 0,
	}, {
		wantErr:         assert.Error,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.True,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps, failUps},
		},
		name:              "fastest_multiple_bad",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     2,
		wantFallbackCount: 3,
	}, {
		wantErr:         assert.NoError,
		wantMainErr:     assert.True,
		wantFallbackErr: assert.False,
		config: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{failUps, failUps},
		},
		fallbackConfig: &proxy.UpstreamConfig{
			Upstreams: []upstream.Upstream{ups},
		},
		name:              "fastest_bad_fallback_success",
		mode:              proxy.UpstreamModeFastestAddr,
		wantMainCount:     2,
		wantFallbackCount: 1,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			conf.UpstreamConfig = tc.config
			conf.Fallbacks = tc.fallbackConfig
			conf.UpstreamMode = tc.mode

			// The proxy is only used to resolve the request directly, so it's
			// not started.
			p, err := proxy.New(conf)
			require.NoError(t, err)

			d := &proxy.DNSContext{Req: testReq}

			err = p.Resolve(d)
			tc.wantErr(t, err)

			stats := d.QueryStatistics()
			assertQueryStats(
				t,
				stats,
				tc.wantMainCount,
				tc.wantMainErr,
				tc.wantFallbackCount,
				tc.wantFallbackErr,
			)
		})
	}
}

// assertQueryStats checks that stats has the expected number of main and
// fallback entries.  It also checks whether each list contains a lookup error,
// using wantMainErr and wantFallbackErr respectively.
func assertQueryStats(
	t *testing.T,
	stats *proxy.QueryStatistics,
	wantMainCount int,
	wantMainErr assert.BoolAssertionFunc,
	wantFallbackCount int,
	wantFallbackErr assert.BoolAssertionFunc,
) {
	t.Helper()

	mainStats, fallbackStats := stats.Main(), stats.Fallback()

	assert.Lenf(t, mainStats, wantMainCount, "main stats count")
	assert.Lenf(t, fallbackStats, wantFallbackCount, "fallback stats count")

	wantMainErr(t, isErrorInStats(mainStats), "main err")
	wantFallbackErr(t, isErrorInStats(fallbackStats), "fallback err")
}

// isErrorInStats is a test helper that reports whether any of the upstream
// statistics in stats contains a DNS lookup error.
func isErrorInStats(stats []*proxy.UpstreamStatistics) (ok bool) {
	for i := range stats {
		if ok = stats[i].Error != nil; ok {
			break
		}
	}

	return ok
}
0707010000007C000081A4000000000000000000000001679A649F000005ED000000000000000000000000000000000000002600000000dnsproxy-0.75.0/proxy/upstreammode.gopackage proxy

import (
	"encoding"
	"fmt"
)

// UpstreamMode is an enumeration of upstream mode representations.  Its text
// encoding is the string value of the corresponding constant, see
// [UpstreamMode.MarshalText] and [UpstreamMode.UnmarshalText].
//
// TODO(d.kolyshev): Set uint8 as underlying type.
type UpstreamMode string

const (
	// UpstreamModeLoadBalance is the default upstream mode.  It balances the
	// load between the upstreams.
	UpstreamModeLoadBalance UpstreamMode = "load_balance"

	// UpstreamModeParallel makes the server query all configured upstream
	// servers in parallel.
	UpstreamModeParallel UpstreamMode = "parallel"

	// UpstreamModeFastestAddr controls whether the server should respond to A
	// or AAAA requests only with the fastest IP address detected by ICMP
	// response time or TCP connection time.
	UpstreamModeFastestAddr UpstreamMode = "fastest_addr"
)

// type check
var _ encoding.TextUnmarshaler = (*UpstreamMode)(nil)

// UnmarshalText implements the [encoding.TextUnmarshaler] interface for
// *UpstreamMode.  It accepts only the names of the modes declared above.  For
// any other input it returns an error and leaves m unchanged.
func (m *UpstreamMode) UnmarshalText(b []byte) (err error) {
	mode := UpstreamMode(b)
	if mode != UpstreamModeLoadBalance &&
		mode != UpstreamModeParallel &&
		mode != UpstreamModeFastestAddr {
		return fmt.Errorf(
			"invalid upstream mode %q, supported: %q, %q, %q",
			b,
			UpstreamModeLoadBalance,
			UpstreamModeParallel,
			UpstreamModeFastestAddr,
		)
	}

	*m = mode

	return nil
}

// type check
var _ encoding.TextMarshaler = UpstreamMode("")

// MarshalText implements the [encoding.TextMarshaler] interface for
// UpstreamMode.  The text is the mode's string value as is.
func (m UpstreamMode) MarshalText() (text []byte, err error) {
	text = []byte(m)

	return text, nil
}
0707010000007D000081A4000000000000000000000001679A649F0000014C000000000000000000000000000000000000002B00000000dnsproxy-0.75.0/proxy/upstreammode_test.gopackage proxy_test

import (
	"testing"

	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/golibs/testutil"
)

// TestUpstreamMode_encoding checks that every supported upstream mode survives
// a round trip through its text encoding.
func TestUpstreamMode_encoding(t *testing.T) {
	t.Parallel()

	modes := []proxy.UpstreamMode{
		proxy.UpstreamModeLoadBalance,
		proxy.UpstreamModeParallel,
		proxy.UpstreamModeFastestAddr,
	}

	for _, mode := range modes {
		// Declare per-iteration copies, so that the parallel subtest closures
		// don't depend on loop variable semantics.
		v := mode
		text := string(mode)

		t.Run(text, func(t *testing.T) {
			t.Parallel()

			testutil.AssertMarshalText(t, text, &v)
			testutil.AssertUnmarshalText(t, text, &v)
		})
	}
}
0707010000007E000081A4000000000000000000000001679A649F000038C1000000000000000000000000000000000000002300000000dnsproxy-0.75.0/proxy/upstreams.gopackage proxy

import (
	"fmt"
	"io"
	"log/slog"
	"maps"
	"slices"
	"strings"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/container"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/netutil"
)

// UnqualifiedNames is a key for [UpstreamConfig.DomainReservedUpstreams] map to
// specify the upstreams only used for resolving domain names consisting of a
// single label.  [ParseUpstreamsConfig] stores an empty domain in the
// "[/.../]" syntax, e.g. "[//]1.1.1.1", under this key.
const UnqualifiedNames = "unqualified_names"

// UpstreamConfig maps domain names to upstreams.
//
// When the config is produced by [ParseUpstreamsConfig], the domain keys are
// lowercased and have a trailing dot, or are [UnqualifiedNames].  A nil
// upstream list marks a domain excluded with the "[/domain/]#" syntax.
type UpstreamConfig struct {
	// DomainReservedUpstreams maps the domains to the upstreams.  For a
	// "*.domain" specification, the entry for the domain holds only the
	// wildcard upstreams.
	DomainReservedUpstreams map[string][]upstream.Upstream

	// SpecifiedDomainUpstreams maps the specific domain names to the upstreams.
	// Only specifications without the "*." prefix populate it.
	SpecifiedDomainUpstreams map[string][]upstream.Upstream

	// SubdomainExclusions is set of domains with subdomains exclusions, i.e.
	// the domains that appeared with the "*." prefix.
	SubdomainExclusions *container.MapSet[string]

	// Upstreams is a list of default upstreams.
	Upstreams []upstream.Upstream
}

// type check
var _ io.Closer = (*UpstreamConfig)(nil)

// ParseUpstreamsConfig returns an UpstreamConfig and nil error if the upstream
// configuration is valid.  Otherwise returns a partially filled UpstreamConfig
// and wrapped error containing lines with errors.  It also skips empty lines
// and comments (lines starting with "#").
//
// opts may be nil.  It is never modified; when a default logger has to be set,
// a copy of opts is used instead.
//
// # Simple upstreams
//
// Single upstream per line.  For example:
//
//	1.2.3.4
//	3.4.5.6
//
// # Domain specific upstreams
//
//   - reserved upstreams: [/domain1/../domainN/]<upstreamString>
//   - subdomains only upstreams: [/*.domain1/../*.domainN/]<upstreamString>
//
// Where <upstreamString> is one or many upstreams separated by space (e.g.
// `1.1.1.1` or `1.1.1.1 2.2.2.2`).
//
// More specific domains take priority over less specific domains.  To exclude
// more specific domains from reserved upstreams querying you should use the
// following syntax:
//
//	[/domain1/../domainN/]#
//
// So the following config:
//
//	[/host.com/]1.2.3.4
//	[/www.host.com/]2.3.4.5
//	[/maps.host.com/news.host.com/]#
//	3.4.5.6
//
// will send queries for *.host.com to 1.2.3.4.  Except for *.www.host.com,
// which will go to 2.3.4.5.  And *.maps.host.com or *.news.host.com, which
// will go to default server 3.4.5.6 with all other domains.
//
// To exclude top level domain from reserved upstreams querying you could use
// the following:
//
//	[/*.domain.com/]<upstreamString>
//
// So the following config:
//
//	[/*.domain.com/]1.2.3.4
//	3.4.5.6
//
// will send queries for all subdomains *.domain.com to 1.2.3.4, but domain.com
// query will be sent to default server 3.4.5.6 as every other query.
//
// TODO(e.burkov):  Consider supporting multiple upstreams in a single line for
// default upstream syntax.
func ParseUpstreamsConfig(
	lines []string,
	opts *upstream.Options,
) (conf *UpstreamConfig, err error) {
	if opts == nil {
		opts = &upstream.Options{}
	} else if opts.Logger == nil {
		// Don't modify the caller's options, since they may be shared, e.g.
		// between goroutines or several parsing calls.
		opts = opts.Clone()
	}

	if opts.Logger == nil {
		opts.Logger = slog.Default()
	}

	p := &configParser{
		options:                  opts,
		logger:                   opts.Logger,
		upstreamsIndex:           map[string]upstream.Upstream{},
		domainReservedUpstreams:  map[string][]upstream.Upstream{},
		specifiedDomainUpstreams: map[string][]upstream.Upstream{},
		subdomainsOnlyUpstreams:  map[string][]upstream.Upstream{},
		subdomainsOnlyExclusions: container.NewMapSet[string](),
	}

	return p.parse(lines)
}

// ParseError is an error which contains an index of the line of the upstream
// list.  [ParseUpstreamsConfig] joins one ParseError per invalid line into the
// returned error.
type ParseError struct {
	// err is the original error.
	err error

	// Idx is an index of the lines.  See [ParseUpstreamsConfig].  It's
	// zero-based.
	Idx int
}

// type check
var _ error = (*ParseError)(nil)

// Error implements the [error] interface for *ParseError.  The message
// contains the zero-based line index followed by the original error.
func (e *ParseError) Error() (msg string) {
	msg = fmt.Sprintf("parsing error at index %d: %s", e.Idx, e.err)

	return msg
}

// type check
var _ errors.Wrapper = (*ParseError)(nil)

// Unwrap implements the [errors.Wrapper] interface for *ParseError.  It
// returns the original error, so that errors.Is and errors.As can see past the
// line index.
func (e *ParseError) Unwrap() (unwrapped error) {
	unwrapped = e.err

	return unwrapped
}

// configParser collects the results of parsing an upstream config.
type configParser struct {
	// options contains upstream properties.  Each created upstream gets its
	// own clone of it.
	options *upstream.Options

	// logger is used for logging during parsing.  It's never nil.
	logger *slog.Logger

	// upstreamsIndex is used to avoid creating duplicates of upstreams.  It's
	// keyed by the upstream string as written in the config.
	upstreamsIndex map[string]upstream.Upstream

	// domainReservedUpstreams is a map of reserved domains and lists of
	// corresponding upstreams.  A nil list marks an excluded domain.
	domainReservedUpstreams map[string][]upstream.Upstream

	// specifiedDomainUpstreams is a map of domains specified without the "*."
	// prefix and lists of corresponding upstreams.  A nil list marks an
	// excluded domain.
	specifiedDomainUpstreams map[string][]upstream.Upstream

	// subdomainsOnlyUpstreams is a map of wildcard subdomains and lists of
	// corresponding upstreams.  It's keyed by the domain with the "*." prefix
	// removed.
	subdomainsOnlyUpstreams map[string][]upstream.Upstream

	// subdomainsOnlyExclusions is set of domains with subdomains exclusions.
	subdomainsOnlyExclusions *container.MapSet[string]

	// upstreams is a list of default upstreams.
	upstreams []upstream.Upstream
}

// parse parses every line and returns the resulting UpstreamConfig.  The
// config is returned even when some lines are invalid.  The returned error
// joins one [*ParseError] per invalid line and is nil if all lines are valid.
func (p *configParser) parse(lines []string) (c *UpstreamConfig, err error) {
	var errs []error
	for idx, line := range lines {
		lineErr := p.parseLine(idx, line)
		if lineErr == nil {
			continue
		}

		errs = append(errs, &ParseError{Idx: idx, err: lineErr})
	}

	// Wildcard specifications replace whatever was collected for the same
	// domain, so that only the "*." upstreams (or nil for exclusions) remain in
	// the reserved map for it.
	maps.Copy(p.domainReservedUpstreams, p.subdomainsOnlyUpstreams)

	c = &UpstreamConfig{
		Upstreams:                p.upstreams,
		DomainReservedUpstreams:  p.domainReservedUpstreams,
		SpecifiedDomainUpstreams: p.specifiedDomainUpstreams,
		SubdomainExclusions:      p.subdomainsOnlyExclusions,
	}

	return c, errors.Join(errs...)
}

// parseLine parses a single configuration line with the zero-based index idx
// and records its upstreams in p.  Empty lines and comments are skipped.  A
// "[/domains/]#" line excludes the domains from reserved upstreams.  It returns
// an error if the line is invalid.
func (p *configParser) parseLine(idx int, confLine string) (err error) {
	if confLine == "" || strings.HasPrefix(confLine, "#") {
		return nil
	}

	ups, domains, err := splitConfigLine(confLine)
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	isExclusion := ups[0] == "#" && len(domains) > 0
	if isExclusion {
		p.excludeFromReserved(domains)

		return nil
	}

	for i := range ups {
		if err = p.specifyUpstream(domains, ups[i], idx); err != nil {
			// Don't wrap the error since it's informative enough as is.
			return err
		}
	}

	return nil
}

// splitConfigLine parses an upstream configuration line.  It returns the list of
// upstream addresses (one or many) and the list of domains for which these
// upstreams are reserved; the domains are nil for a line without a domain
// specification.
//
// A line without the "[/" prefix is returned as is as a single upstream.
// Domains are lowercased and get a trailing dot.  An empty domain is returned
// as [UnqualifiedNames].
//
// It returns an error if the upstream format is incorrect, including a domain
// specification followed by no upstreams at all, or if a domain is invalid.
func splitConfigLine(confLine string) (upstreams, domains []string, err error) {
	if !strings.HasPrefix(confLine, "[/") {
		return []string{confLine}, nil, nil
	}

	domainsLine, upstreamsLine, found := strings.Cut(confLine[len("[/"):], "/]")
	if !found {
		return nil, nil, errors.Error("wrong upstream format")
	}

	// Check the fields rather than the raw string.  A whitespace-only upstreams
	// part would otherwise yield an empty list, and callers index its first
	// element.
	upstreams = strings.Fields(upstreamsLine)
	if len(upstreams) == 0 {
		return nil, nil, errors.Error("wrong upstream format")
	}

	// split domains list
	for _, confHost := range strings.Split(domainsLine, "/") {
		if confHost == "" {
			// empty domain specification means `unqualified names only`
			domains = append(domains, UnqualifiedNames)

			continue
		}

		host := strings.TrimPrefix(confHost, "*.")
		if err = netutil.ValidateDomainName(host); err != nil {
			return nil, nil, err
		}

		domains = append(domains, strings.ToLower(confHost+"."))
	}

	return upstreams, domains, nil
}

// specifyUpstream creates the upstream u, or reuses one already created for
// the same string, and records it for domains.  With no domains, u becomes a
// default upstream.  idx is the index of the configuration line and is only
// used for logging.  It returns an error if the upstream can't be created; the
// error wraps the one from upstream.AddressToUpstream.
func (p *configParser) specifyUpstream(domains []string, u string, idx int) (err error) {
	dnsUpstream, ok := p.upstreamsIndex[u]
	// TODO(e.burkov):  Improve identifying duplicate upstreams.
	if !ok {
		// Create the upstream and save it to the index, so that the same
		// upstream string on another line reuses it.
		dnsUpstream, err = upstream.AddressToUpstream(u, p.options.Clone())
		if err != nil {
			return fmt.Errorf("cannot prepare the upstream: %w", err)
		}

		p.upstreamsIndex[u] = dnsUpstream
	}

	addr := dnsUpstream.Address()
	if len(domains) > 0 {
		p.includeToReserved(dnsUpstream, domains)

		p.logger.Debug(
			"upstream is reserved",
			"idx", idx,
			"addr", addr,
			"domains_num", len(domains),
		)

		return nil
	}

	// TODO(s.chzhen):  Handle duplicates.
	p.upstreams = append(p.upstreams, dnsUpstream)

	// TODO(s.chzhen):  Logs without index.
	p.logger.Debug("set upstream", "idx", idx, "addr", addr)

	return nil
}

// excludeFromReserved marks domains as excluded from querying the reserved
// upstreams.  A wildcard domain ("*.name.") excludes only the subdomains of
// "name.", while a plain domain is excluded itself.
func (p *configParser) excludeFromReserved(domains []string) {
	for _, host := range domains {
		trimmed, isWildcard := strings.CutPrefix(host, "*.")
		if !isWildcard {
			p.domainReservedUpstreams[host] = nil
			p.specifiedDomainUpstreams[host] = nil

			continue
		}

		p.subdomainsOnlyExclusions.Add(trimmed)
		p.subdomainsOnlyUpstreams[trimmed] = nil
	}
}

// includeToReserved reserves dnsUpstream for querying domains.  For a wildcard
// domain ("*.name.") the upstream is recorded under "name." in the
// subdomains-only lists; otherwise it is recorded as specified for the domain
// itself.  In both cases it is added to the domain-reserved upstreams.
func (p *configParser) includeToReserved(dnsUpstream upstream.Upstream, domains []string) {
	for _, d := range domains {
		name, isWildcard := strings.CutPrefix(d, "*.")
		if isWildcard {
			p.subdomainsOnlyExclusions.Add(name)
			p.logger.Debug("domain is added to exclusions list", "domain", name)

			p.subdomainsOnlyUpstreams[name] = append(p.subdomainsOnlyUpstreams[name], dnsUpstream)
		} else {
			p.specifiedDomainUpstreams[name] = append(p.specifiedDomainUpstreams[name], dnsUpstream)
		}

		p.domainReservedUpstreams[name] = append(p.domainReservedUpstreams[name], dnsUpstream)
	}
}

// validate returns an error if the upstreams aren't configured properly.  uc is
// considered valid if it contains at least a single default upstream.  A nil uc
// and a uc with only domain-specific upstreams are reported with their own
// errors, while an empty uc causes [upstream.ErrNoUpstreams].
func (uc *UpstreamConfig) validate() (err error) {
	const (
		errNilConf   errors.Error = "upstream config is nil"
		errNoDefault errors.Error = "no default upstreams specified"
	)

	if uc == nil {
		return errNilConf
	}

	if len(uc.Upstreams) != 0 {
		return nil
	}

	hasDomainSpecific := len(uc.DomainReservedUpstreams) != 0 || len(uc.SpecifiedDomainUpstreams) != 0
	if !hasDomainSpecific {
		return upstream.ErrNoUpstreams
	}

	return errNoDefault
}

// ValidatePrivateConfig returns an error if uc isn't valid or if, treated as a
// private upstreams configuration, it contains specifications for domains that
// aren't reversed addresses from privateSubnets.  All such domains are checked
// in sorted order, and their errors are joined.
func ValidatePrivateConfig(uc *UpstreamConfig, privateSubnets netutil.SubnetSet) (err error) {
	err = uc.validate()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	domains := slices.Sorted(maps.Keys(uc.DomainReservedUpstreams))

	var errs []error
	for _, domain := range domains {
		pref, extErr := netutil.ExtractReversedAddr(domain)
		if extErr != nil {
			// Don't wrap the error since it's informative enough as is.
			errs = append(errs, extErr)

			continue
		}

		// A zero-length prefix, presumably the whole address family like
		// in-addr.arpa, is always allowed.
		if pref.Bits() != 0 && !privateSubnets.Contains(pref.Addr()) {
			errs = append(errs, fmt.Errorf("reversed subnet in %q is not private", domain))
		}
	}

	return errors.Join(errs...)
}

// getUpstreamsForDomain returns the upstreams specified for resolving fqdn,
// which is matched case-insensitively.  It returns the default upstreams if
// fqdn isn't reserved for any other upstreams or has been explicitly excluded
// from the reserved ones.
//
// More specific domains take priority over less specific ones.  For example, if
// upstreams are specified for the following domains:
//
//   - host.com
//   - www.host.com
//
// then the request for mail.host.com will be resolved using the upstreams
// specified for host.com.  Single-label names without a specification of their
// own are looked up under [UnqualifiedNames].
func (uc *UpstreamConfig) getUpstreamsForDomain(fqdn string) (ups []upstream.Upstream) {
	if len(uc.DomainReservedUpstreams) == 0 {
		return uc.Upstreams
	}

	name := strings.ToLower(fqdn)
	if uc.SubdomainExclusions.Has(name) {
		return uc.lookupSubdomainExclusion(name)
	}

	if found, ok := uc.lookupUpstreams(name); ok {
		return found
	}

	// Walk up the parent domains, starting with the unqualified names
	// specification if the name had only a single label.
	_, parent, _ := strings.Cut(name, ".")
	if parent == "" {
		parent = UnqualifiedNames
	}

	for ; parent != ""; _, parent, _ = strings.Cut(parent, ".") {
		if found, ok := uc.lookupUpstreams(parent); ok {
			return found
		}
	}

	return uc.Upstreams
}

// getUpstreamsForDS is like [getUpstreamsForDomain], but intended for DS
// queries only, so it matches fqdn with its first label removed.  If nothing is
// left after removing the label, the default upstreams are returned.
//
// A DS RRset SHOULD be present at a delegation point when the child zone is
// signed.  The DS RRset MAY contain multiple records, each referencing a public
// key in the child zone used to verify the RRSIGs in that zone.  All DS RRsets
// in a zone MUST be signed, and DS RRsets MUST NOT appear at a zone's apex.
//
// See https://datatracker.ietf.org/doc/html/rfc4035#section-2.4
func (uc *UpstreamConfig) getUpstreamsForDS(fqdn string) (ups []upstream.Upstream) {
	if _, parent, _ := strings.Cut(fqdn, "."); parent != "" {
		return uc.getUpstreamsForDomain(parent)
	}

	return uc.Upstreams
}

// lookupSubdomainExclusion returns upstreams for the host from the subdomain
// exclusions list.  It prefers the upstreams specified for host itself, then
// the ones reserved for its parent domain, and falls back to the default
// upstreams.
func (uc *UpstreamConfig) lookupSubdomainExclusion(host string) (u []upstream.Upstream) {
	if ups := uc.SpecifiedDomainUpstreams[host]; len(ups) > 0 {
		return ups
	}

	// Check if there is a spec for the upper level domain.  Use strings.Cut
	// instead of indexing the result of strings.SplitAfterN, which panics for
	// a host without any dots.
	_, parent, _ := strings.Cut(host, ".")
	if ups := uc.DomainReservedUpstreams[parent]; len(ups) > 0 {
		return ups
	}

	return uc.Upstreams
}

// lookupUpstreams returns the upstreams reserved for the domain name and
// reports whether name has any reservation at all.  For a name that has been
// excluded from the reserved upstreams, it returns the default upstreams and
// true.
func (uc *UpstreamConfig) lookupUpstreams(name string) (ups []upstream.Upstream, ok bool) {
	reserved, found := uc.DomainReservedUpstreams[name]

	switch {
	case !found:
		return reserved, false
	case len(reserved) == 0:
		// The domain has been excluded from reserved upstreams querying.
		return uc.Upstreams, true
	default:
		return reserved, true
	}
}

// Close implements the [io.Closer] interface for *UpstreamConfig.  It closes
// the default upstreams first and then the upstreams of each domain, going
// through the domains in lexical order so that the resulting error is
// deterministic.
//
// NOTE(review):  The same upstream instance may be listed for several domains
// and in both DomainReservedUpstreams and SpecifiedDomainUpstreams, so it may
// be closed more than once.
func (uc *UpstreamConfig) Close() (err error) {
	errs := closeAll(nil, uc.Upstreams...)

	domainMaps := []map[string][]upstream.Upstream{
		uc.DomainReservedUpstreams,
		uc.SpecifiedDomainUpstreams,
	}
	for _, byDomain := range domainMaps {
		for _, domain := range slices.Sorted(maps.Keys(byDomain)) {
			errs = closeAll(errs, byDomain[domain]...)
		}
	}

	if len(errs) == 0 {
		return nil
	}

	return fmt.Errorf("failed to close some upstreams: %w", errors.Join(errs...))
}
0707010000007F000081A4000000000000000000000001679A649F00002D00000000000000000000000000000000000000003100000000dnsproxy-0.75.0/proxy/upstreams_internal_test.gopackage proxy

import (
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TODO(e.burkov):  Call [testing.T.Parallel] in this file.

// Domain specifications and the questions matching them, used in tests of
// [UpstreamConfig].  Specifications are written without the trailing dot, as
// in configuration lines, while questions are FQDNs.
const (
	// A single-label question and a question that matches no specification.
	unqualifiedFQDN = "unqualified."
	unspecifiedFQDN = "unspecified.domain."

	topLevelDomain = "example"
	topLevelFQDN   = topLevelDomain + "."

	// wildcardFirstLevelDomain matches the subdomains of topLevelDomain, but
	// not topLevelDomain itself.
	firstLevelDomain         = "name." + topLevelDomain
	firstLevelFQDN           = firstLevelDomain + "."
	wildcardFirstLevelDomain = "*." + topLevelDomain

	subDomain = "sub." + firstLevelDomain
	subFQDN   = subDomain + "."

	generalDomain = "general." + firstLevelDomain
	generalFQDN   = generalDomain + "."

	// anotherSubFQDN is a subdomain of firstLevelDomain with no specification
	// of its own, so it is only matched by wildcardDomain.
	wildcardDomain = "*." + firstLevelDomain
	anotherSubFQDN = "another." + firstLevelDomain + "."
)

// Upstream URLs used in tests of [UpstreamConfig].
const (
	generalUpstream     = "tcp://general.upstream:53"
	unqualifiedUpstream = "tcp://unqualified.upstream:53"
	tldUpstream         = "tcp://tld.upstream:53"
	domainUpstream      = "tcp://domain.upstream:53"
	wildcardUpstream    = "tcp://wildcard.upstream:53"
	subdomainUpstream   = "tcp://subdomain.upstream:53"
)

// testUpstreamConfigLines is the common set of upstream configurations used in
// tests of [UpstreamConfig].
var testUpstreamConfigLines = []string{
	generalUpstream,
	// An empty domain specification stands for unqualified names.
	"[//]" + unqualifiedUpstream,
	"[/" + topLevelDomain + "/]" + tldUpstream,
	// "#" in place of the upstream excludes the domains from the reserved
	// upstreams.
	"[/" + wildcardFirstLevelDomain + "/]#",
	"[/" + firstLevelDomain + "/]" + domainUpstream,
	"[/" + wildcardDomain + "/]" + wildcardUpstream,
	"[/" + generalDomain + "/]#",
	"[/" + subDomain + "/]" + subdomainUpstream,
}

func TestUpstreamConfig_GetUpstreamsForDomain(t *testing.T) {
	t.Parallel()

	conf, err := ParseUpstreamsConfig(testUpstreamConfigLines, nil)
	require.NoError(t, err)

	// Each row is the subtest name, the question, and the expected upstream
	// addresses in order.
	testCases := []struct {
		name string
		in   string
		want []string
	}{
		{"unspecified", unspecifiedFQDN, []string{generalUpstream}},
		{"unqualified", unqualifiedFQDN, []string{unqualifiedUpstream}},
		{"tld", topLevelFQDN, []string{tldUpstream}},
		{"unspecified_subdomain", unspecifiedFQDN + topLevelFQDN, []string{generalUpstream}},
		{"domain", firstLevelFQDN, []string{domainUpstream}},
		{"wildcard", anotherSubFQDN, []string{wildcardUpstream}},
		{"general", generalFQDN, []string{generalUpstream}},
		{"subdomain", subFQDN, []string{subdomainUpstream}},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			assertUpstreamsAddrs(t, conf.getUpstreamsForDomain(tc.in), tc.want)
		})
	}
}

func TestUpstreamConfig_GetUpstreamsForDS(t *testing.T) {
	t.Parallel()

	conf, err := ParseUpstreamsConfig(testUpstreamConfigLines, nil)
	require.NoError(t, err)

	// Each row is the subtest name, the DS question, and the expected upstream
	// addresses in order.  The first label of the question is ignored.
	testCases := []struct {
		name string
		in   string
		want []string
	}{
		{"unspecified", unspecifiedFQDN, []string{unqualifiedUpstream}},
		{"unqualified", unqualifiedFQDN, []string{generalUpstream}},
		{"tld", topLevelFQDN, []string{generalUpstream}},
		{"unspecified_subdomain", unspecifiedFQDN + topLevelFQDN, []string{generalUpstream}},
		{"domain", firstLevelFQDN, []string{tldUpstream}},
		{"wildcard", anotherSubFQDN, []string{domainUpstream}},
		{"general", "label." + generalFQDN, []string{generalUpstream}},
		{"subdomain", "label." + subFQDN, []string{subdomainUpstream}},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			assertUpstreamsAddrs(t, conf.getUpstreamsForDS(tc.in), tc.want)
		})
	}
}

func TestUpstreamConfig_Validate(t *testing.T) {
	// Each row is the subtest name, the expected validation error, and the
	// configuration lines.
	testCases := []struct {
		name    string
		wantErr error
		in      []string
	}{
		{"empty", upstream.ErrNoUpstreams, []string{}},
		{"nil", upstream.ErrNoUpstreams, nil},
		{"valid", nil, []string{"udp://upstream.example:53"}},
		{"no_default", errors.Error("no default upstreams specified"), []string{
			"[/domain.example/]udp://upstream.example:53",
			"[/another.domain.example/]#",
		}},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			conf, err := ParseUpstreamsConfig(tc.in, nil)
			require.NoError(t, err)

			assert.ErrorIs(t, conf.validate(), tc.wantErr)
		})
	}

	t.Run("actual_nil", func(t *testing.T) {
		var nilConf *UpstreamConfig

		assert.ErrorIs(t, nilConf.validate(), errors.Error("upstream config is nil"))
	})
}

func TestValidatePrivateConfig(t *testing.T) {
	privateNets := netutil.SubnetSetFunc(netutil.IsLocallyServed)

	// Each row is the subtest name, the expected error message, and the private
	// upstream configuration line to check along with a default upstream.
	testCases := []struct {
		name    string
		wantErr string
		u       string
	}{
		{"success_address", ``, "[/1.0.0.127.in-addr.arpa/]#"},
		{"success_subnet", ``, "[/127.in-addr.arpa/]#"},
		{"success_v4_family", ``, "[/in-addr.arpa/]#"},
		{"success_v6_family", ``, "[/ip6.arpa/]#"},
		{
			"bad_arpa_domain",
			`bad arpa domain name "arpa": not a reversed ip network`,
			"[/arpa/]#",
		},
		{
			"not_arpa_subnet",
			`bad arpa domain name "hello.world": not a reversed ip network`,
			"[/hello.world/]#",
		},
		{
			"non-private_arpa_address",
			`reversed subnet in "1.2.3.4.in-addr.arpa." is not private`,
			"[/1.2.3.4.in-addr.arpa/]#",
		},
		{
			"non-private_arpa_subnet",
			`reversed subnet in "128.in-addr.arpa." is not private`,
			"[/128.in-addr.arpa/]#",
		},
		{
			"several_bad",
			`reversed subnet in "1.2.3.4.in-addr.arpa." is not private` +
				"\n" + `bad arpa domain name "non.arpa": not a reversed ip network`,
			"[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
		},
		{"partial_good", "", "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#"},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			conf, err := ParseUpstreamsConfig([]string{"192.168.0.1", tc.u}, nil)
			require.NoError(t, err)

			testutil.AssertErrorMsg(t, tc.wantErr, ValidatePrivateConfig(conf, privateNets))
		})
	}
}

// TestGetUpstreamsForDomainWithoutDuplicates checks that the same upstream
// address specified for several domains is parsed into a single instance.
func TestGetUpstreamsForDomainWithoutDuplicates(t *testing.T) {
	upstreams := []string{"[/example.com/]1.1.1.1", "[/example.org/]1.1.1.1"}
	config, err := ParseUpstreamsConfig(upstreams, &upstream.Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: false,
		Bootstrap:          nil,
		Timeout:            testTimeout,
	})
	// Use require for the preconditions, since the checks below dereference
	// config and index the upstream lists, which would panic otherwise.
	require.NoError(t, err)

	assert.Len(t, config.Upstreams, 0)
	require.Len(t, config.DomainReservedUpstreams, 2)

	comUps := config.DomainReservedUpstreams["example.com."]
	require.NotEmpty(t, comUps)

	orgUps := config.DomainReservedUpstreams["example.org."]
	require.NotEmpty(t, orgUps)

	// Check that the very same Upstream instance is used for both domains.
	assert.Same(t, comUps[0], orgUps[0])
}

func TestGetUpstreamsForDomain_wildcards(t *testing.T) {
	lines := []string{
		"0.0.0.1",
		"[/a.x/]0.0.0.2",
		"[/*.a.x/]0.0.0.3",
		"[/b.a.x/]0.0.0.4",
		"[/*.b.a.x/]0.0.0.5",
		"[/*.x.z/]0.0.0.6",
		"[/c.b.a.x/]#",
	}

	conf, err := ParseUpstreamsConfig(lines, nil)
	require.NoError(t, err)

	// Each row is the subtest name, the question, and the expected upstream
	// addresses in order.
	testCases := []struct {
		name string
		in   string
		want []string
	}{
		{"default", "d.x.", []string{"0.0.0.1:53"}},
		{"specified_one", "a.x.", []string{"0.0.0.2:53"}},
		{"wildcard", "c.a.x.", []string{"0.0.0.3:53"}},
		{"specified_two", "b.a.x.", []string{"0.0.0.4:53"}},
		{"wildcard_two", "d.b.a.x.", []string{"0.0.0.5:53"}},
		{"specified_three", "c.b.a.x.", []string{"0.0.0.1:53"}},
		{"specified_four", "d.c.b.a.x.", []string{"0.0.0.1:53"}},
		{"unspecified_wildcard", "x.z.", []string{"0.0.0.1:53"}},
		{"unspecified_wildcard_sub", "a.x.z.", []string{"0.0.0.6:53"}},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assertUpstreamsAddrs(t, conf.getUpstreamsForDomain(tc.in), tc.want)
		})
	}
}

func TestGetUpstreamsForDomain_sub_wildcards(t *testing.T) {
	lines := []string{
		"0.0.0.1",
		"[/a.x/]0.0.0.2",
		"[/*.a.x/]0.0.0.3",
		"[/*.b.a.x/]0.0.0.5",
	}

	conf, err := ParseUpstreamsConfig(lines, nil)
	require.NoError(t, err)

	// Each row is the subtest name, the question, and the expected upstream
	// addresses in order.
	testCases := []struct {
		name string
		in   string
		want []string
	}{
		{"specified", "a.x.", []string{"0.0.0.2:53"}},
		{"wildcard", "c.a.x.", []string{"0.0.0.3:53"}},
		{"sub_spec_ignore", "b.a.x.", []string{"0.0.0.3:53"}},
		{"sub_spec_wildcard", "d.b.a.x.", []string{"0.0.0.5:53"}},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assertUpstreamsAddrs(t, conf.getUpstreamsForDomain(tc.in), tc.want)
		})
	}
}

func TestGetUpstreamsForDomain_default_wildcards(t *testing.T) {
	lines := []string{
		"127.0.0.1:5301",
		"[/example.org/]127.0.0.1:5302",
		"[/*.example.org/]127.0.0.1:5303",
		"[/www.example.org/]127.0.0.1:5304",
		"[/*.www.example.org/]#",
	}

	conf, err := ParseUpstreamsConfig(lines, nil)
	require.NoError(t, err)

	// Each row is the subtest name, the question, and the expected upstream
	// addresses in order.
	testCases := []struct {
		name string
		in   string
		want []string
	}{
		{"domain", "example.org.", []string{"127.0.0.1:5302"}},
		{"sub_wildcard", "sub.example.org.", []string{"127.0.0.1:5303"}},
		{"spec_sub", "www.example.org.", []string{"127.0.0.1:5304"}},
		{"def_wildcard", "abc.www.example.org.", []string{"127.0.0.1:5301"}},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assertUpstreamsAddrs(t, conf.getUpstreamsForDomain(tc.in), tc.want)
		})
	}
}

// upsSink is the typed sink variable for the result of benchmarked function.
// Assigning the results to a package-level variable keeps the compiler from
// optimizing the benchmarked calls away.
var upsSink []upstream.Upstream

func BenchmarkGetUpstreamsForDomain(b *testing.B) {
	upstreams := []string{
		"[/google.com/local/]4.3.2.1",
		"[/www.google.com//]1.2.3.4",
		"[/maps.google.com/]#",
		"[/www.google.com/]tls://1.1.1.1",
	}

	config, err := ParseUpstreamsConfig(upstreams, &upstream.Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: false,
		Bootstrap:          nil,
		Timeout:            testTimeout,
	})
	// Fail early instead of benchmarking a nil or partial configuration.
	require.NoError(b, err)

	domains := []string{
		"www.google.com.",
		"www2.google.com.",
		"internal.local.",
		"google.",
		"maps.google.com.",
	}

	l := len(domains)

	// Exclude the configuration parsing above from the measurements.
	b.ReportAllocs()
	b.ResetTimer()

	for i := range b.N {
		upsSink = config.getUpstreamsForDomain(domains[i%l])
	}
}

// assertUpstreamsAddrs checks that the addresses of ups exactly match want,
// including their order.  It stops the test if the lengths differ.
func assertUpstreamsAddrs(tb testing.TB, ups []upstream.Upstream, want []string) {
	tb.Helper()

	require.Len(tb, ups, len(want))

	for i := range ups {
		got := ups[i].Address()
		assert.Equalf(tb, want[i], got, "at index %d", i)
	}
}
07070100000080000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001A00000000dnsproxy-0.75.0/proxyutil07070100000081000081A4000000000000000000000001679A649F000002A8000000000000000000000000000000000000002100000000dnsproxy-0.75.0/proxyutil/dns.go// Package proxyutil contains helper functions that are used in all other
// dnsproxy packages.
package proxyutil

import (
	"encoding/binary"
	"net/netip"

	"github.com/miekg/dns"
)

// AddPrefix returns a new slice that consists of the length of the DNS message
// b as a 2-byte big-endian number followed by a copy of b.  The length is
// truncated to 16 bits, so b is expected to be at most 65535 bytes long.
func AddPrefix(b []byte) (m []byte) {
	m = make([]byte, 0, 2+len(b))
	m = binary.BigEndian.AppendUint16(m, uint16(len(b)))

	return append(m, b...)
}

// IPFromRR returns the IP address from rr if rr is an A or AAAA record.  For
// any other record type, or if the address data can't be converted, it returns
// the zero netip.Addr.
func IPFromRR(rr dns.RR) (ip netip.Addr) {
	switch rec := rr.(type) {
	case *dns.A:
		// Use the 4-byte form so that the result is a plain IPv4 address and
		// not an IPv4-mapped IPv6 one.
		ip, _ = netip.AddrFromSlice(rec.A.To4())
	case *dns.AAAA:
		ip, _ = netip.AddrFromSlice(rec.AAAA)
	}

	// On a conversion failure AddrFromSlice returns the zero Addr, which is the
	// same result as for the other record types.
	return ip
}
07070100000082000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001800000000dnsproxy-0.75.0/scripts07070100000083000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001E00000000dnsproxy-0.75.0/scripts/hooks07070100000084000081ED000000000000000000000001679A649F00000867000000000000000000000000000000000000002900000000dnsproxy-0.75.0/scripts/hooks/pre-commit#!/bin/sh

# Exit on a failing command (-e), disable filename expansion (-f), and treat
# undefined variables as errors (-u).
set -e -f -u

# This comment is used to simplify checking local copies of the script.
# Bump this number every time a significant change is made to this
# script.
#
# AdGuard-Project-Version: 2

# TODO(a.garipov): Add pre-merge-commit.

# Only show interactive prompts if a terminal is attached to stdout.  While
# this technically doesn't guarantee that reading from /dev/tty works, this
# should work reasonably well on all of our supported development systems and
# in most terminal emulators.
is_tty='0'
if [ -t '1' ]
then
	is_tty='1'
fi
readonly is_tty

# prompt is a helper that prompts the user for interactive input if that
# can be done.  If there is no terminal attached, it sleeps for two
# seconds, giving the programmer some time to react, and returns with
# a zero exit code.  Answering "y" continues the commit, while an empty
# answer or "n" aborts it; any other answer repeats the question.
prompt() {
	if [ "$is_tty" -eq '0' ]
	then
		sleep 2

		return 0
	fi

	while true
	do
		printf 'commit anyway? y/[n]: '
		read -r ans < /dev/tty

		case "$ans"
		in
		('y'|'Y')
			break
			;;
		(''|'n'|'N')
			exit 1
			;;
		(*)
			continue
			;;
		esac
	done
}

# Warn the programmer about unstaged changes and untracked files, but do
# not fail the commit, because those changes might be temporary or for
# a different branch.
#
# The program works on the output of git status --porcelain=2: the first rule
# prints the path of entries whose worktree status isn't ".", and the second
# one prints the paths of untracked ("?") entries.
#
# NOTE(review):  For untracked entries $9 is presumably empty, so the first
# rule may also print a blank line, and for renamed ("2") entries $9 is
# presumably the score rather than the path.  Neither affects the emptiness
# check below.
awk_prog='substr($2, 2, 1) != "." { print $9; } $1 == "?" { print $2; }'
readonly awk_prog

unstaged="$( git status --porcelain=2 | awk "$awk_prog" )"
readonly unstaged

if [ "$unstaged" != "" ]
then
	printf 'WARNING: you have unstaged changes:\n\n%s\n\n' "$unstaged"
	prompt
fi

# Warn the programmer about temporary todos, but do not fail the commit,
# because the commit could be in a temporary branch.  The "|| :" keeps a
# non-zero exit status of git grep, e.g. when nothing matches, from aborting
# the script under set -e.
temp_todos="$( git grep -e 'TODO.*!!' -- ':!scripts/hooks/pre-commit' || : )"
readonly temp_todos

if [ "$temp_todos" != "" ]
then
	printf 'WARNING: you have temporary todos:\n\n%s\n\n' "$temp_todos"
	prompt
fi

verbose="${VERBOSE:-0}"
readonly verbose

# Lint the text files only if any of them are staged.
if [ "$( git diff --cached --name-only -- '*.md' '*.yaml' '*.yml' )" ]
then
	make VERBOSE="$verbose" txt-lint
fi

# Check, lint, and test the Go code only if any related files are staged.
if [ "$( git diff --cached --name-only -- '*.go' '*.mod' '*.sh' 'Makefile' )" ]
then
	make VERBOSE="$verbose" go-os-check go-lint go-test
fi
07070100000085000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001D00000000dnsproxy-0.75.0/scripts/make07070100000086000081A4000000000000000000000001679A649F00000BC4000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/scripts/make/build-docker.sh#!/bin/sh

verbose="${VERBOSE:-0}"

# Trace the executed commands and pass the debug flag to docker if the caller
# requested any verbosity.
if [ "$verbose" -gt '0' ]
then
	set -x
	debug_flags='--debug=1'
else
	set +x
	debug_flags='--debug=0'
fi
readonly debug_flags

set -e -f -u

# Require these to be set.
commit="${REVISION:?please set REVISION}"
dist_dir="${DIST_DIR:?please set DIST_DIR}"
version="${VERSION:?please set VERSION}"
readonly commit dist_dir version

# Allow users to use sudo.
sudo_cmd="${SUDO:-}"
readonly sudo_cmd

# The comma-separated list of platforms for the --platform flag of docker buildx
# build.  Each of them needs a binary copied below.
docker_platforms="\
linux/386,\
linux/amd64,\
linux/arm/v6,\
linux/arm/v7,\
linux/arm64,\
linux/ppc64le"
readonly docker_platforms

build_date="$( date -u +'%Y-%m-%dT%H:%M:%SZ' )"
readonly build_date

# Set DOCKER_IMAGE_NAME to 'adguard/dnsproxy' if you want (and are allowed)
# to push to DockerHub.
docker_image_name="${DOCKER_IMAGE_NAME:-dnsproxy-dev}"
readonly docker_image_name

# Set DOCKER_OUTPUT to 'type=image,name=adguard/dnsproxy,push=true' if you
# want (and are allowed) to push to DockerHub.
#
# If you want to inspect the resulting image using commands like "docker image
# ls", change type to docker and also set docker_platforms to a single platform.
#
# See https://github.com/docker/buildx/issues/166.
docker_output="${DOCKER_OUTPUT:-type=image,name=${docker_image_name},push=false}"
readonly docker_output

docker_version_tag="--tag=${docker_image_name}:${version}"
docker_channel_tag="--tag=${docker_image_name}:latest"

# If version is set to 'dev' or empty, only set the version tag and avoid
# polluting the "latest" tag.  Note that version can't actually be empty here,
# since VERSION is required to be set and non-empty above.
if [ "${version:-}" = 'dev' ] || [ "${version:-}" = '' ]
then
  docker_channel_tag=""
fi

readonly docker_version_tag docker_channel_tag

# Copy the binaries into a new directory under new names, so that it's easier to
# COPY them later.  DO NOT remove the trailing underscores.  See file
# docker/Dockerfile.  The ARM binaries end with their variant instead, which
# presumably matches the naming expected there; confirm against the Dockerfile
# before changing any of the names.
dist_docker="${dist_dir}/docker"
readonly dist_docker

mkdir -p "$dist_docker"
cp "${dist_dir}/linux-386/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_386_"
cp "${dist_dir}/linux-amd64/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_amd64_"
cp "${dist_dir}/linux-arm64/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_arm64_"
cp "${dist_dir}/linux-arm6/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_arm_v6"
cp "${dist_dir}/linux-arm7/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_arm_v7"
cp "${dist_dir}/linux-ppc64le/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_ppc64le_"

# Prepare the default configuration for the Docker image.
cp ./config.yaml.dist "${dist_docker}/config.yaml"

# Don't use quotes with $docker_version_tag and $docker_channel_tag, because we
# want word splitting, or no argument at all if a tag is empty.  The same goes
# for $sudo_cmd.
#
# TODO(a.garipov): Once flag --tag of docker buildx build supports commas, use
# them instead.
$sudo_cmd docker\
	"$debug_flags"\
	buildx build\
	--build-arg BUILD_DATE="$build_date"\
	--build-arg DIST_DIR="$dist_dir"\
	--build-arg VCS_REF="$commit"\
	--build-arg VERSION="$version"\
	--output "$docker_output"\
	--platform "$docker_platforms"\
	$docker_version_tag\
	$docker_channel_tag\
	-f ./docker/Dockerfile\
	.
07070100000087000081A4000000000000000000000001679A649F00000CEE000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/scripts/make/build-release.sh#!/bin/sh

verbose="${VERBOSE:-0}"
readonly verbose

if [ "$verbose" -gt '2' ]
then
	env
	set -x
elif [ "$verbose" -gt '1' ]
then
	set -x
fi

set -e -f -u

# log prints its first argument to stderr if the verbosity level is above zero.
log() {
	if [ "$verbose" -gt '0' ]
	then
		echo "$1" 1>&2
	fi
}

log 'starting to build dnsproxy release'

version="${VERSION:-}"
readonly version

log "version '$version'"

dist="${DIST_DIR:-build}"
readonly dist

out="${OUT:-dnsproxy}"
readonly out

log "checking tools"

# gzip is required as well, since the tarballs are compressed with it below.
for tool in gzip tar zip
do
	if ! command -v "$tool" > /dev/null
	then
		# Always report the missing tool, since otherwise the script exits
		# with an error and no explanation at the default verbosity level.
		echo "tool '$tool' not found" 1>&2

		exit 1
	fi
done

# Data section.  Arrange data into space-separated tables for read -r to read.
# Use 0 for missing values.

#    os  arch      arm mips
platforms="\
darwin   amd64     0   0
darwin   arm64     0   0
freebsd  386       0   0
freebsd  amd64     0   0
freebsd  arm       5   0
freebsd  arm       6   0
freebsd  arm       7   0
freebsd  arm64     0   0
linux    386       0   0
linux    amd64     0   0
linux    arm       5   0
linux    arm       6   0
linux    arm       7   0
linux    arm64     0   0
linux    mips      0   softfloat
linux    mips64    0   softfloat
linux    mips64le  0   softfloat
linux    mipsle    0   softfloat
linux    ppc64le   0   0
openbsd  amd64     0   0
openbsd  arm64     0   0
windows  386       0   0
windows  amd64     0   0
windows  arm64     0   0"
readonly platforms

# build builds the release named $1 for OS $2, architecture $3, ARM version $4,
# and MIPS float mode $5, and packs it into an archive in $dist.
build() {
	# Get the arguments.  Here and below, use the "build_" prefix for all
	# variables local to function build.
	build_dir="${dist}/${1}"\
		build_name="$1"\
		build_os="$2"\
		build_arch="$3"\
		build_arm="$4"\
		build_mips="$5"\
		;

	# Use the ".exe" filename extension if we build a Windows release.
	if [ "$build_os" = 'windows' ]
	then
		build_output="./${build_dir}/${out}.exe"
	else
		build_output="./${build_dir}/${out}"
	fi

	mkdir -p "./${build_dir}"

	# Build the binary.
	#
	# Set GOARM and GOMIPS to an empty string if $build_arm and $build_mips
	# are zero by removing the zero as if it's a prefix.
	#
	# Use $build_os rather than the caller's loop variable so that the
	# function only depends on its own arguments.
	env\
		GOARCH="$build_arch"\
		GOARM="${build_arm#0}"\
		GOMIPS="${build_mips#0}"\
		GOOS="$build_os"\
		VERBOSE="$(( verbose - 1 ))"\
		VERSION="$version"\
		OUT="$build_output"\
		sh ./scripts/make/go-build.sh\
		;

	log "$build_output"

	# Prepare the build directory for archiving.
	cp ./LICENSE ./README.md "$build_dir"

	# Make archives.  Windows prefers ZIP archives; the rest, gzipped tarballs.
	case "$build_os"
	in
	('windows')
		build_archive="./${dist}/${out}-${build_name}-${version}.zip"
		# TODO(a.garipov): Find an option similar to the -C option of tar for
		# zip.
		( cd "${dist}" && zip -9 -q -r "../${build_archive}" "./${build_name}" )
		;;
	(*)
		build_archive="./${dist}/${out}-${build_name}-${version}.tar.gz"
		tar -C "./${dist}" -c -f - "./${build_name}" | gzip -9 - > "$build_archive"
		;;
	esac

	log "$build_archive"
}

log "starting builds"

# Go over all platforms defined in the space-separated table above, tweak the
# values where necessary, and feed to build.
echo "$platforms" | while read -r os arch arm mips
do
	case "$arch"
	in
	(arm)
		name="${os}-${arch}${arm}"
		;;
	(*)
		name="${os}-${arch}"
		;;
	esac

	build "$name" "$os" "$arch" "$arm" "$mips"
done

log "finished"
07070100000088000081A4000000000000000000000001679A649F00000B62000000000000000000000000000000000000002900000000dnsproxy-0.75.0/scripts/make/go-build.sh#!/bin/sh

# dnsproxy build script
#
# The commentary in this file is written with the assumption that the reader
# only has superficial knowledge of the POSIX shell language and the like.
# Experienced readers may find it overly verbose.

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1

# The default verbosity level is 0.  Show every command that is run and every
# package that is processed if the caller requested verbosity level greater than
# 0.  Also show subcommands and print the environment if the requested
# verbosity level is greater than 1.  Otherwise, do nothing.
verbose="${VERBOSE:-0}"
readonly verbose

if [ "$verbose" -gt '1' ]
then
	env
	set -x
	v_flags='-v=1'
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=0'
else
	set +x
	v_flags='-v=0'
	x_flags='-x=0'
fi
readonly x_flags v_flags

# Exit the script if a command fails (-e), prevent accidental filename
# expansion (-f), and consider undefined variables as errors (-u).
set -e -f -u

# Allow users to override the go command from environment.  For example, to
# build two releases with two different Go versions and test the difference.
go="${GO:-go}"
readonly go

# Set the build parameters unless already set.
branch="${BRANCH:-$( git rev-parse --abbrev-ref HEAD )}"
revision="${REVISION:-$( git rev-parse --short HEAD )}"
version="${VERSION:-0}"
readonly branch revision version

# Set date and time of the latest commit unless already set.
committime="${SOURCE_DATE_EPOCH:-$( git log -1 --pretty=%ct )}"
readonly committime

# Compile the build parameters into the variables of the version package using
# the linker's -X flags.  The -s and -w flags strip the symbol table and the
# debug information from the binary.
version_pkg='github.com/AdguardTeam/dnsproxy/internal/version'
ldflags="-s -w"
ldflags="${ldflags} -X ${version_pkg}.branch=${branch}"
ldflags="${ldflags} -X ${version_pkg}.committime=${committime}"
ldflags="${ldflags} -X ${version_pkg}.revision=${revision}"
ldflags="${ldflags} -X ${version_pkg}.version=${version}"
readonly ldflags version_pkg

# Allow users to limit the build's parallelism.
parallelism="${PARALLELISM:-}"
readonly parallelism

# Use GOFLAGS for -p, because -p=0 simply disables the build instead of leaving
# the default value.
if [ "${parallelism}" != '' ]
then
        GOFLAGS="${GOFLAGS:-} -p=${parallelism}"
fi
readonly GOFLAGS
export GOFLAGS

# Allow users to specify a different output name.
out="${OUT:-dnsproxy}"
readonly out

o_flags="-o=${out}"
readonly o_flags

# Allow users to enable the race detector.  Unfortunately, that means that cgo
# must be enabled.
if [ "${RACE:-0}" -eq '0' ]
then
	CGO_ENABLED='0'
	race_flags='--race=0'
else
	CGO_ENABLED='1'
	race_flags='--race=1'
fi
readonly CGO_ENABLED race_flags
export CGO_ENABLED

# Always build in module mode.
GO111MODULE='on'
export GO111MODULE

if [ "$verbose" -gt '0' ]
then
	"$go" env
fi

# Use --trimpath to remove the local file system paths from the binary.
"$go" build\
	--ldflags="$ldflags"\
	"$race_flags"\
	--trimpath\
	"$o_flags"\
	"$v_flags"\
	"$x_flags"
07070100000089000081A4000000000000000000000001679A649F000001D8000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/go-deps.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1

verbose="${VERBOSE:-0}"
readonly verbose

# Verbosity levels:
#   0 = Don't print anything except for errors.
#   1 = Print commands, but not nested commands.
#   2 = Print everything, including the environment.
if [ "$verbose" -gt '1' ]
then
	env
	set -x
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	x_flags='-x=0'
else
	set +x
	x_flags='-x=0'
fi
readonly x_flags

# Exit the script if a pipeline fails (-e), prevent accidental filename
# expansion (-f), and consider undefined variables as errors (-u).
set -e -f -u

# Allow users to override the go command from environment.
go="${GO:-go}"
readonly go

# Download the module dependencies; with -x=1 the go command also prints the
# commands it executes.
"$go" mod download "$x_flags"
0707010000008A000081A4000000000000000000000001679A649F0000112C000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/go-lint.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 7

verbose="${VERBOSE:-0}"
readonly verbose

# Print the commands at any nonzero verbosity level.
if [ "$verbose" -gt '0' ]
then
	set -x
fi

# Set $EXIT_ON_ERROR to zero to see all errors.
if [ "${EXIT_ON_ERROR:-1}" -eq '0' ]
then
	set +e
else
	set -e
fi

# Prevent accidental filename expansion (-f) and consider undefined variables
# as errors (-u).
set -f -u



# Source the common helpers, including not_found and run_linter.
. ./scripts/make/helper.sh



# Simple analyzers

# blocklist_imports is a simple check against unwanted packages.  The following
# packages are banned:
#
#   *  Package errors is replaced by our own package in the
#      github.com/AdguardTeam/golibs module.
#
#   *  Package io/ioutil is soft-deprecated.
#
#   *  Package log and github.com/AdguardTeam/golibs/log are replaced by
#      stdlib's new package log/slog and AdGuard's new utilities package
#      github.com/AdguardTeam/golibs/logutil/slogutil.
#
#   *  Package reflect is often an overkill, and for deep comparisons there are
#      much better functions in module github.com/google/go-cmp, which is
#      already our indirect dependency and which may or may not enter the
#      stdlib at some point.
#
#      See https://github.com/golang/go/issues/45200.
#
#   *  Package sort is replaced by package slices.
#
#   *  Package unsafe is… unsafe.
#
#   *  Package golang.org/x/exp/slices has been moved into stdlib.
#
#   *  Package golang.org/x/net/context has been moved into stdlib.
#
# Currently, the only standard exception is files generated from protobuf
# schemas, which use package reflect.  If your project needs more exceptions,
# add and document them.
#
# The sed expression uses the POSIX interval \{1,\} instead of \+, which is not
# defined for POSIX basic regular expressions, so that it works with non-GNU
# sed implementations as well.
#
# NOTE: The exit status of the pipeline is that of sed, so "|| exit 0" only
# applies when sed itself fails.
#
# TODO(a.garipov): Add deprecated package golang.org/x/exp/maps once all
# projects switch to Go 1.23.
blocklist_imports() {
	git grep\
		-e '[[:space:]]"errors"$'\
		-e '[[:space:]]"github.com/AdguardTeam/golibs/log"$'\
		-e '[[:space:]]"golang.org/x/exp/slices"$'\
		-e '[[:space:]]"golang.org/x/net/context"$'\
		-e '[[:space:]]"io/ioutil"$'\
		-e '[[:space:]]"log"$'\
		-e '[[:space:]]"reflect"$'\
		-e '[[:space:]]"sort"$'\
		-e '[[:space:]]"unsafe"$'\
		-n\
		-- '*.go'\
		':!*.pb.go'\
		| sed -e 's/^\([^[:space:]]\{1,\}\)\(.*\)$/\1 blocked import:\2/'\
		|| exit 0
}

# method_const is a simple check against the usage of some raw strings and
# numbers where one should use named constants.  Currently, it reports raw HTTP
# method names in Go files.
#
# The sed expression uses the POSIX interval \{1,\} instead of the non-POSIX \+
# so that it works with non-GNU sed implementations as well.
method_const() {
	git grep -F\
		-e '"DELETE"'\
		-e '"GET"'\
		-e '"PATCH"'\
		-e '"POST"'\
		-e '"PUT"'\
		-n\
		-- '*.go'\
		| sed -e 's/^\([^[:space:]]\{1,\}\)\(.*\)$/\1 http method literal:\2/'\
		|| exit 0
}

# underscores is a simple check against Go filenames with underscores.  Add new
# build tags and OS as you go.  The main goal of this check is to discourage the
# use of filenames like client_manager.go.
#
# Every reported file name is indented with a tab.  A literal tab character and
# the POSIX "&" are used in the sed replacement instead of the GNU-only "\t"
# and "\0".
underscores() {
	tab="$( printf '\t' )"
	readonly tab

	underscore_files="$(
		git ls-files '*_*.go'\
			| grep -F\
			-e '_darwin.go'\
			-e '_generate.go'\
			-e '_linux.go'\
			-e '_others.go'\
			-e '_plan9.go'\
			-e '_test.go'\
			-e '_unix.go'\
			-e '_windows.go'\
			-e '_dnscrypt.go'\
			-e '_https.go'\
			-e '_quic.go'\
			-e '_tcp.go'\
			-e '_udp.go'\
			-v\
			| sed -e "s/./${tab}&/"
	)"
	readonly underscore_files

	if [ "$underscore_files" != '' ]
	then
		echo 'found file names with underscores:'
		echo "$underscore_files"
	fi
}

# TODO(a.garipov): Add an analyzer to look for `fallthrough`, `goto`, and `new`?



# Checks

run_linter -e blocklist_imports

run_linter -e method_const

run_linter -e underscores

run_linter -e gofumpt --extra -e -l .

# TODO(a.garipov): golint is deprecated, find a suitable replacement.

run_linter "${GO:-go}" vet ./...

run_linter govulncheck ./...

run_linter gocyclo --over 10 .

run_linter gocognit --over 10 .

run_linter ineffassign ./...

run_linter unparam ./...

# misspell_files runs misspell on the tracked files of the listed types.  It is
# run through run_linter instead of a pipeline ending in sed.  The exit status
# of a pipeline is that of its last command, so a failure reported by
# "misspell --error" would otherwise be lost and would never fail the script.
misspell_files() {
	git ls-files -- 'Makefile' '*.conf' '*.go' '*.mod' '*.sh' '*.yaml' '*.yml'\
		| xargs misspell --error
}

run_linter -e misspell_files

run_linter looppointer ./...

run_linter nilness ./...

run_linter fieldalignment ./...

run_linter -e shadow --strict ./...

# TODO(a.garipov):  Re-enable G115.
run_linter gosec --exclude G115 --quiet ./...

run_linter errcheck ./...

# staticcheck_matrix lists the GOOS values to run staticcheck for, one
# configuration per line, as read by staticcheck --matrix from stdin.
staticcheck_matrix='
darwin:  GOOS=darwin
freebsd: GOOS=freebsd
linux:   GOOS=linux
openbsd: GOOS=openbsd
windows: GOOS=windows
'
readonly staticcheck_matrix

echo "$staticcheck_matrix" | run_linter staticcheck --matrix ./...
0707010000008B000081A4000000000000000000000001679A649F000003E8000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/go-test.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1

verbose="${VERBOSE:-0}"
readonly verbose

# Verbosity levels:
#   0 = Don't print anything except for errors.
#   1 = Print commands, but not nested commands.
#   2 = Print everything.
if [ "$verbose" -gt '1' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=0'
else
	set +x
	v_flags='-v=0'
	x_flags='-x=0'
fi
readonly v_flags x_flags

# Exit the script if a pipeline fails (-e), prevent accidental filename
# expansion (-f), and consider undefined variables as errors (-u).
set -e -f -u

# Unlike the build script, this one enables the race detector by default; set
# RACE to 0 to disable it.
if [ "${RACE:-1}" -eq '0' ]
then
	race_flags='--race=0'
else
	race_flags='--race=1'
fi
readonly race_flags

# --count=1 makes the go command run the tests even when their results are
# cached, and --shuffle=on randomizes the order of the tests.  If
# TIMEOUT_FLAGS is set, it replaces the whole timeout flag, not only its value.
go="${GO:-go}"
count_flags='--count=1'
shuffle_flags='--shuffle=on'
timeout_flags="${TIMEOUT_FLAGS:---timeout=2m}"
readonly go count_flags shuffle_flags timeout_flags

"$go" test\
	"$count_flags"\
	"$race_flags"\
	"$shuffle_flags"\
	"$timeout_flags"\
	"$v_flags"\
	"$x_flags"\
	./...
0707010000008C000081A4000000000000000000000001679A649F00000766000000000000000000000000000000000000002900000000dnsproxy-0.75.0/scripts/make/go-tools.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 3

verbose="${VERBOSE:-0}"
readonly verbose

# Verbosity levels:
#   0 = Don't print anything except for errors.
#   1 = Print commands, but not nested commands.
#   2 = Print everything.
if [ "$verbose" -gt '1' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=0'
else
	set +x
	v_flags='-v=0'
	x_flags='-x=0'
fi
readonly v_flags x_flags

# Exit the script if a pipeline fails (-e), prevent accidental filename
# expansion (-f), and consider undefined variables as errors (-u).
set -e -f -u

# Allow users to override the go command from environment.
go="${GO:-go}"
readonly go

# Remove only the actual binaries in the bin/ directory, as developers may add
# their own scripts there.  Most commonly, a script named “go” for tools that
# call the go binary and need a particular version.
rm -f\
	bin/errcheck\
	bin/fieldalignment\
	bin/gocognit\
	bin/gocyclo\
	bin/gofumpt\
	bin/gosec\
	bin/govulncheck\
	bin/ineffassign\
	bin/looppointer\
	bin/misspell\
	bin/nilness\
	bin/shadow\
	bin/staticcheck\
	bin/unparam\
	;

# Reset GOARCH and GOOS to make sure we install the tools for the native
# architecture even when we're cross-compiling the main binary, and also to
# prevent the "cannot install cross-compiled binaries when GOBIN is set" error.
#
# The tool versions come from ./internal/tools/go.mod, and the workspace mode
# is turned off (GOWORK='off') so that a go.work file cannot affect them.
env\
	GOARCH=""\
	GOBIN="${PWD}/bin"\
	GOOS=""\
	GOWORK='off'\
	"$go" install\
	--modfile=./internal/tools/go.mod\
	"$v_flags"\
	"$x_flags"\
	github.com/fzipp/gocyclo/cmd/gocyclo\
	github.com/golangci/misspell/cmd/misspell\
	github.com/gordonklaus/ineffassign\
	github.com/kisielk/errcheck\
	github.com/kyoh86/looppointer/cmd/looppointer\
	github.com/securego/gosec/v2/cmd/gosec\
	github.com/uudashr/gocognit/cmd/gocognit\
	golang.org/x/tools/go/analysis/passes/fieldalignment/cmd/fieldalignment\
	golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness\
	golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow\
	golang.org/x/vuln/cmd/govulncheck\
	honnef.co/go/tools/cmd/staticcheck\
	mvdan.cc/gofumpt\
	mvdan.cc/unparam\
	;
0707010000008D000081A4000000000000000000000001679A649F000001F7000000000000000000000000000000000000002D00000000dnsproxy-0.75.0/scripts/make/go-upd-tools.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 2

verbose="${VERBOSE:-0}"
readonly verbose

# Verbosity levels:
#   0 = Don't print anything except for errors.
#   1 = Print commands, but not nested commands.
#   2 = Print everything, including the environment.
if [ "$verbose" -gt '1' ]
then
	env
	set -x
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	x_flags='-x=0'
else
	set +x
	x_flags='-x=0'
fi
readonly x_flags

# Exit the script if a pipeline fails (-e), prevent accidental filename
# expansion (-f), and consider undefined variables as errors (-u).
set -e -f -u

# Allow users to override the go command from environment.
go="${GO:-go}"
readonly go

# Update the dependencies of the separate tools module, not of the main one.
cd ./internal/tools/

"$go" get -u "$x_flags"
"$go" mod tidy
0707010000008E000081A4000000000000000000000001679A649F00000651000000000000000000000000000000000000002700000000dnsproxy-0.75.0/scripts/make/helper.sh#!/bin/sh

# Common script helpers
#
# This file contains common script helpers.  It should be sourced in scripts
# right after the initial environment processing.

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a remarkable change is made to this script.
#
# AdGuard-Project-Version: 3



# Deferred helpers

# not_found_msg is the hint printed by not_found.
not_found_msg='
looks like a binary not found error.
make sure you have installed the linter binaries using:

	$ make go-tools
'
readonly not_found_msg

# not_found prints not_found_msg to stderr if the script exits with status 127.
# It is installed as the EXIT trap below, so it applies to every script that
# sources this file.
not_found() {
	if [ "$?" -eq '127' ]
	then
		# Code 127 is the exit status a shell uses when a command or a file is
		# not found, according to the Bash Hackers wiki.
		#
		# See https://wiki.bash-hackers.org/dict/terms/exit_status.
		echo "$not_found_msg" 1>&2
	fi
}
trap not_found EXIT



# Helpers

# run_linter runs the given linter with two additions:
#
# 1.  If the first argument is "-e", run_linter exits with a nonzero exit code
#     if there is anything in the command's standard output.
#
# 2.  In any case, run_linter prefixes every line of the command's standard
#     output with the program's name.
#
# Only the standard output is captured; the standard error of the command is
# passed through unchanged.  If the command fails, its exit code is returned.
# The body is a subshell, so the set commands and the variables inside it do
# not affect the caller.
run_linter() (
	set +e

	# Don't trace the helper itself unless the maximum verbosity is requested.
	if [ "${VERBOSE:-0}" -lt '2' ]
	then
		set +x
	fi

	cmd="${1:?run_linter: provide a command}"
	shift

	exit_on_output='0'
	if [ "$cmd" = '-e' ]
	then
		exit_on_output='1'
		cmd="${1:?run_linter: provide a command}"
		shift
	fi

	readonly cmd

	output="$( "$cmd" "$@" )"
	exitcode="$?"

	readonly output

	if [ "$output" != '' ]
	then
		echo "$output" | sed -e "s/^/${cmd}: /"

		# With -e, any output is a failure even if the command succeeded.
		if [ "$exitcode" -eq '0' ] && [ "$exit_on_output" -eq '1' ]
		then
			exitcode='1'
		fi
	fi

	return "$exitcode"
)
0707010000008F000081A4000000000000000000000001679A649F0000018B000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/md-lint.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a remarkable change is made to this script.
#
# AdGuard-Project-Version: 2

verbose="${VERBOSE:-0}"
readonly verbose

# Exit the script if a pipeline fails (-e), prevent accidental filename
# expansion (-f), and consider undefined variables as errors (-u).
set -e -f -u

# Print the commands at any nonzero verbosity level.
if [ "$verbose" -gt '0' ]
then
	set -x
fi

# NOTE: Adjust for your project.
# markdownlint\
# 	./README.md\
# 	;

# TODO(e.burkov):  Lint README.md.
07070100000090000081A4000000000000000000000001679A649F000001CE000000000000000000000000000000000000002800000000dnsproxy-0.75.0/scripts/make/sh-lint.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a remarkable change is made to this script.
#
# AdGuard-Project-Version: 2

verbose="${VERBOSE:-0}"
readonly verbose

# Don't use -f, because we use globs in this script.
set -e -u

# Print the commands at any nonzero verbosity level.
if [ "$verbose" -gt '0' ]
then
	set -x
fi

# Check all hook and make scripts: exclude SC2250 (-e), use the gcc-like output
# format (-f), enable all optional checks (-o), and follow sourced files (-x).
#
# NOTE: Adjust for your project.
shellcheck -e 'SC2250' -f 'gcc' -o 'all' -x --\
	./scripts/hooks/*\
	./scripts/make/*\
	;
07070100000091000081A4000000000000000000000001679A649F000006A9000000000000000000000000000000000000002900000000dnsproxy-0.75.0/scripts/make/txt-lint.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a remarkable change is made to this script.
#
# AdGuard-Project-Version: 5

verbose="${VERBOSE:-0}"
readonly verbose

# Print the commands at any nonzero verbosity level.
if [ "$verbose" -gt '0' ]
then
	set -x
fi

# Set $EXIT_ON_ERROR to zero to see all errors.
if [ "${EXIT_ON_ERROR:-1}" -eq '0' ]
then
	set +e
else
	set -e
fi

# We don't need glob expansions and we want to see errors about unset variables.
set -f -u

# Source the common helpers, including not_found.
. ./scripts/make/helper.sh

# Simple analyzers

# trailing_newlines is a simple check that makes sure that all plain-text files
# have a trailing newline to make sure that all tools work correctly with them.
#
# Command substitution removes trailing newlines, so both $nl and $final_byte
# are empty strings when a file ends with a newline (or is empty), while any
# other final byte makes them differ.
trailing_newlines() (
	nl="$( printf "\n" )"
	readonly nl

	# NOTE: Adjust for your project.
	git ls-files\
		':!*.bmp'\
		':!*.jpg'\
		':!*.mmdb'\
		':!*.png'\
		':!*.tar.gz'\
		':!*.webp'\
		':!*.zip'\
		| while read -r f
		do
			final_byte="$( tail -c -1 "$f" )"
			if [ "$final_byte" != "$nl" ]
			then
				printf '%s: must have a trailing newline\n' "$f"
			fi
		done
)

# trailing_whitespace is a simple check that makes sure that there is no
# trailing whitespace in plain-text files.  Each finding is printed as
# "file:line:content", with trailing spaces surrounded by ">>>" and "<<<".
#
# The sed expression uses the POSIX interval \{1,\} instead of the non-POSIX \+
# so that it works with non-GNU sed implementations as well.
#
# NOTE(review): The file name is used inside a sed expression, so names that
# contain ":", "&", or "\" would be misinterpreted by sed.
trailing_whitespace() {
	# NOTE: Adjust for your project.
	git ls-files\
		':!*.bmp'\
		':!*.jpg'\
		':!*.mmdb'\
		':!*.png'\
		':!*.tar.gz'\
		':!*.webp'\
		':!*.zip'\
		| while read -r f
		do
			grep -e '[[:space:]]$' -n -- "$f"\
				| sed -e "s:^:${f}\::" -e 's/ \{1,\}$/>>>&<<</'
		done
}

run_linter -e trailing_newlines

run_linter -e trailing_whitespace

# misspell_files runs misspell on the tracked files of the listed types.  It is
# run through run_linter instead of a pipeline ending in sed.  The exit status
# of a pipeline is that of its last command, so a failure reported by
# "misspell --error" would otherwise be lost and would never fail the script.
misspell_files() {
	git ls-files -- '*.conf' '*.md' '*.txt' '*.yaml' '*.yml'\
		| xargs misspell --error
}

run_linter -e misspell_files
07070100000092000081A4000000000000000000000001679A649F0000025B000000000000000000000000000000000000002100000000dnsproxy-0.75.0/staticcheck.conf# This comment is used to simplify checking local copies of the staticcheck
# configuration.  Bump this number every time a significant change is made to
# this file.
#
# AdGuard-Project-Version: 1
checks = ["all"]
initialisms = [
  # See https://github.com/dominikh/go-tools/blob/master/config/config.go.
  #
  # Do not add "PTR" since we use "Ptr" as a suffix.
  "inherit"
, "ASN"
, "DHCP"
, "DNSSEC"
  # E.g. SentryDSN.
, "DSN"
, "ECS"
, "EDNS"
, "MX"
, "QUIC"
, "RA"
, "RRSIG"
, "RTT"
, "SDNS"
, "SLAAC"
, "SOA"
, "SVCB"
, "TLD"
, "WHOIS"
]
dot_import_whitelist = []
http_status_code_whitelist = []
07070100000093000041ED000000000000000000000002679A649F00000000000000000000000000000000000000000000001900000000dnsproxy-0.75.0/upstream07070100000094000081A4000000000000000000000001679A649F000010CD000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/dnscrypt.gopackage upstream

import (
	"fmt"
	"io"
	"log/slog"
	"net/url"
	"os"
	"sync"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
)

// dnsCrypt implements the [Upstream] interface for the DNSCrypt protocol.
type dnsCrypt struct {
	// mu protects client and resolverInfo.
	mu *sync.RWMutex

	// client stores the DNSCrypt client properties.  It is nil until the
	// first successful reset of the client, and it is set to nil again after
	// a failed one.
	client *dnscrypt.Client

	// resolverInfo stores the DNSCrypt server properties.  It is always set
	// and reset together with client.
	resolverInfo *dnscrypt.ResolverInfo

	// addr is the DNSCrypt server URL.
	addr *url.URL

	// logger is used for exchange logging.  It is never nil.
	logger *slog.Logger

	// verifyCert is a callback that verifies the resolver's certificate.  If
	// it is nil, no additional verification is performed.
	verifyCert func(cert *dnscrypt.Cert) (err error)

	// timeout is the timeout for the DNS requests.  It is used for both the
	// UDP and the TCP clients.
	timeout time.Duration
}

// newDNSCrypt returns a new DNSCrypt Upstream for the server at addr.  No
// network activity happens here: the client and the resolver information are
// initialized lazily on the first exchange.  opts must not be nil.
func newDNSCrypt(addr *url.URL, opts *Options) (u *dnsCrypt) {
	u = &dnsCrypt{
		mu:         new(sync.RWMutex),
		addr:       addr,
		logger:     opts.Logger,
		verifyCert: opts.VerifyDNSCryptCertificate,
		timeout:    opts.Timeout,
	}

	return u
}

// type check
var _ Upstream = (*dnsCrypt)(nil)

// Address implements the [Upstream] interface for *dnsCrypt.  It returns the
// string form of the DNSCrypt server URL.
func (p *dnsCrypt) Address() (addr string) {
	addr = p.addr.String()

	return addr
}

// Exchange implements the [Upstream] interface for *dnsCrypt.  If the first
// attempt fails with a deadline or an EOF error, the client is reset, which
// re-fetches the resolver certificate, and the request is sent once more.
func (p *dnsCrypt) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	resp, err = p.exchangeDNSCrypt(req)

	shouldReset := errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF)
	if !shouldReset {
		return resp, err
	}

	// If request times out, it is possible that the server configuration has
	// been changed.  It is safe to assume that the key was rotated, see
	// https://dnscrypt.pl/2017/02/26/how-key-rotation-is-automated.  Re-fetch
	// the server certificate info for new requests to not fail.
	if _, _, err = p.resetClient(); err != nil {
		return nil, err
	}

	return p.exchangeDNSCrypt(req)
}

// Close implements the [Upstream] interface for *dnsCrypt.  It does nothing
// and always returns a nil error.
func (p *dnsCrypt) Close() (err error) { return nil }

// exchangeDNSCrypt sends req to the DNSCrypt server and returns the response.
// The client is reset first when the client or the resolver information is
// missing or when the resolver certificate has expired.  A truncated response
// makes it send the request again over TCP.  A response whose ID differs from
// the ID of req is returned together with [dns.ErrId].
func (p *dnsCrypt) exchangeDNSCrypt(req *dns.Msg) (resp *dns.Msg, err error) {
	var client *dnscrypt.Client
	var resolverInfo *dnscrypt.ResolverInfo
	func() {
		p.mu.RLock()
		defer p.mu.RUnlock()

		client, resolverInfo = p.client, p.resolverInfo
	}()

	// Check the client and server info are set and the certificate is not
	// expired, since any of these cases require a client reset.  The case
	// expressions are evaluated in order, so resolverInfo is not dereferenced
	// when it is nil.
	//
	// TODO(ameshkov):  Consider using [time.Time] for [dnscrypt.Cert.NotAfter].
	switch {
	case
		client == nil,
		resolverInfo == nil,
		resolverInfo.ResolverCert.NotAfter < uint32(time.Now().Unix()):
		client, resolverInfo, err = p.resetClient()
		if err != nil {
			// Don't wrap the error, because it's informative enough as is.
			return nil, err
		}
	default:
		// Go on.
	}

	resp, err = client.Exchange(req, resolverInfo)
	if resp != nil && resp.Truncated {
		// Don't index the question section unconditionally: the request may
		// contain no questions while the server still sets the TC bit, which
		// would otherwise cause a panic here.
		var q *dns.Question
		if len(req.Question) > 0 {
			q = &req.Question[0]
		}

		p.logger.Debug(
			"dnscrypt received truncated, falling back to tcp",
			"addr", p.addr,
			"question", q,
		)

		tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: networkTCP}
		resp, err = tcpClient.Exchange(req, resolverInfo)
	}
	if err == nil && resp != nil && resp.Id != req.Id {
		err = dns.ErrId
	}

	return resp, err
}

// resetClient creates a new UDP DNSCrypt client and fetches the resolver
// information from the server.  If p.verifyCert is set, it also verifies the
// resolver certificate.  Whatever is returned, including nils on failure, is
// stored in p, so a failure triggers another reset on the next request.
func (p *dnsCrypt) resetClient() (client *dnscrypt.Client, ri *dnscrypt.ResolverInfo, err error) {
	addr := p.Address()

	// Save the named results after the return statement has set them.
	defer func() {
		p.mu.Lock()
		defer p.mu.Unlock()

		p.client, p.resolverInfo = client, ri
	}()

	// Use UDP for DNSCrypt upstreams by default.
	client = &dnscrypt.Client{Timeout: p.timeout, Net: networkUDP}

	ri, err = client.Dial(addr)
	if err != nil {
		return nil, nil, fmt.Errorf("fetching certificate info from %s: %w", addr, err)
	}

	if p.verifyCert == nil {
		return client, ri, nil
	}

	err = p.verifyCert(ri.ResolverCert)
	if err != nil {
		return nil, nil, fmt.Errorf("verifying certificate info from %s: %w", addr, err)
	}

	return client, ri, nil
}
07070100000095000081A4000000000000000000000001679A649F00001954000000000000000000000000000000000000003300000000dnsproxy-0.75.0/upstream/dnscrypt_internal_test.gopackage upstream

import (
	"context"
	"net"
	"os"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/ameshkov/dnsstamps"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// dnsCryptHandlerFunc is an adapter that allows using an ordinary function as
// a [dnscrypt.Handler].
type dnsCryptHandlerFunc func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error)

// ServeDNS implements the [dnscrypt.Handler] interface for dnsCryptHandlerFunc.
// It calls f with the same arguments and returns its error.
func (f dnsCryptHandlerFunc) ServeDNS(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
	err = f(w, r)

	return err
}

// startTestDNSCryptServer starts a test DNSCrypt server with the specified
// resolver config and handler.  The server listens on IPv4 localhost using the
// same port for UDP and TCP, and it is shut down during the test cleanup.  It
// returns the stamp of the server.
func startTestDNSCryptServer(
	t testing.TB,
	rc dnscrypt.ResolverConfig,
	h dnscrypt.Handler,
) (stamp dnsstamps.ServerStamp) {
	t.Helper()

	cert, err := rc.CreateCert()
	require.NoError(t, err)

	s := &dnscrypt.Server{
		ProviderName: rc.ProviderName,
		ResolverCert: cert,
		Handler:      h,
	}
	testutil.CleanupAndRequireSuccess(t, func() (err error) {
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()

		return s.Shutdown(ctx)
	})

	localhost := netutil.IPv4Localhost().AsSlice()

	// Prepare TCP listener.
	tcpAddr := &net.TCPAddr{IP: localhost, Port: 0}
	tcpConn, err := net.ListenTCP("tcp", tcpAddr)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, tcpConn.Close)

	// Prepare UDP listener on the same port.
	port := testutil.RequireTypeAssert[*net.TCPAddr](t, tcpConn.Addr()).Port
	udpAddr := &net.UDPAddr{IP: localhost, Port: port}
	udpConn, err := net.ListenUDP("udp", udpAddr)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, udpConn.Close)

	// Start the server.
	go func() {
		udpErr := s.ServeUDP(udpConn)
		require.ErrorIs(testutil.PanicT{}, udpErr, net.ErrClosed)
	}()

	go func() {
		tcpErr := s.ServeTCP(tcpConn)
		require.NoError(testutil.PanicT{}, tcpErr)
	}()

	stamp, err = rc.CreateStamp(udpConn.LocalAddr().String())
	require.NoError(t, err)

	// Check that the TCP port is reachable.  Close the connection during the
	// cleanup so that it doesn't leak; cleanups run in the reverse order, so it
	// is closed before the listeners and the server.
	conn, err := net.Dial("tcp", udpAddr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	return stamp
}

func TestUpstreamDNSCrypt(t *testing.T) {
	t.Parallel()

	// AdGuard DNS (DNSCrypt)
	address := "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"
	u, err := AddressToUpstream(address, &Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: dialTimeout,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Test that it responds properly
	for range 10 {
		checkUpstream(t, u, address)
	}
}

func TestDNSCrypt_Exchange_truncated(t *testing.T) {
	// Prepare the test DNSCrypt server config.
	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	// udpNum and tcpNum count the requests received over UDP and over any other
	// network respectively.
	var udpNum, tcpNum atomic.Uint32
	h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
		switch w.RemoteAddr().Network() {
		case networkUDP:
			udpNum.Add(1)
		default:
			tcpNum.Add(1)
		}

		// Answer with a TXT record that is large enough for the UDP response
		// to be truncated, as the assertions below expect.
		veryLongString := strings.Repeat("VERY LONG STRING", 7)
		txt := make([]string, 0, 50)
		for range 50 {
			txt = append(txt, veryLongString)
		}

		res := (&dns.Msg{}).SetReply(r)
		res.Answer = append(res.Answer, &dns.TXT{
			Hdr: dns.RR_Header{
				Name:   r.Question[0].Name,
				Rrtype: dns.TypeTXT,
				Ttl:    300,
				Class:  dns.ClassINET,
			},
			Txt: txt,
		})

		return w.WriteMsg(res)
	})
	srvStamp := startTestDNSCryptServer(t, rc, h)

	opts := &Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: timeout,
	}
	u, err := AddressToUpstream(srvStamp.String(), opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)

	// Check that the response is not truncated even though it's huge, since
	// the upstream falls back to TCP exactly once.
	res, err := u.Exchange(req)
	require.NoError(t, err)

	assert.False(t, res.Truncated)
	assert.Equal(t, uint32(1), udpNum.Load())
	assert.Equal(t, uint32(1), tcpNum.Load())
}

func TestDNSCrypt_Exchange_deadline(t *testing.T) {
	t.Parallel()

	// Prepare the test DNSCrypt server config.
	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	// The handler never writes a response, so the exchange must time out.
	silent := dnsCryptHandlerFunc(func(_ dnscrypt.ResponseWriter, _ *dns.Msg) (err error) {
		return nil
	})
	srvStamp := startTestDNSCryptServer(t, rc, silent)

	// Use a shorter timeout to speed up the test.
	opts := &Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: 100 * time.Millisecond,
	}
	u, err := AddressToUpstream(srvStamp.String(), opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)
	res, err := u.Exchange(req)
	require.ErrorIs(t, err, os.ErrDeadlineExceeded)

	assert.Nil(t, res)
}

func TestDNSCrypt_Exchange_dialFail(t *testing.T) {
	// Prepare the test DNSCrypt server config.
	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
		return nil
	})

	req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)
	var u Upstream

	// Create an upstream for a server that is shut down when this subtest
	// ends.  The upstream doesn't connect to the server yet.
	require.True(t, t.Run("run_and_shutdown", func(t *testing.T) {
		srvStamp := startTestDNSCryptServer(t, rc, h)

		// Use a shorter timeout to speed up the test.
		u, err = AddressToUpstream(srvStamp.String(), &Options{
			Logger:  slogutil.NewDiscardLogger(),
			Timeout: 100 * time.Millisecond,
		})
		require.NoError(t, err)
	}))

	// The server is shut down, so fetching the certificate must fail.
	require.True(t, t.Run("dial_fail", func(t *testing.T) {
		testutil.CleanupAndRequireSuccess(t, u.Close)

		var res *dns.Msg
		res, err = u.Exchange(req)
		require.Error(t, err)

		assert.Nil(t, res)
	}))

	// Start a new server, but make the upstream reject its certificate.
	t.Run("restart", func(t *testing.T) {
		const validationErr errors.Error = "bad cert"

		srvStamp := startTestDNSCryptServer(t, rc, h)

		// Use a shorter timeout to speed up the test.
		u, err = AddressToUpstream(srvStamp.String(), &Options{
			Logger:  slogutil.NewDiscardLogger(),
			Timeout: 100 * time.Millisecond,
			VerifyDNSCryptCertificate: func(cert *dnscrypt.Cert) (err error) {
				return validationErr
			},
		})
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, u.Close)

		var res *dns.Msg
		res, err = u.Exchange(req)
		require.ErrorIs(t, err, validationErr)

		assert.Nil(t, res)
	})
}
07070100000096000081A4000000000000000000000001679A649F0000544E000000000000000000000000000000000000002000000000dnsproxy-0.75.0/upstream/doh.gopackage upstream

import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"runtime"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/httphdr"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"golang.org/x/net/http2"
)

// Values to configure HTTP and HTTP/2 transport.
const (
	// transportDefaultReadIdleTimeout is the default timeout for pinging
	// idle connections in HTTP/2 transport.
	transportDefaultReadIdleTimeout = 30 * time.Second

	// transportDefaultIdleConnTimeout is the default timeout for idle
	// connections in HTTP transport.
	transportDefaultIdleConnTimeout = 5 * time.Minute

	// dohMaxConnsPerHost controls the maximum number of connections for
	// each host.  Note that setting it to 1 may cause issues with Go's HTTP
	// implementation, see https://github.com/AdguardTeam/dnsproxy/issues/278.
	dohMaxConnsPerHost = 2

	// dohMaxIdleConns controls the maximum number of connections being idle
	// at the same time.
	dohMaxIdleConns = 2
)

// dnsOverHTTPS is a struct that implements the Upstream interface for the
// DNS-over-HTTPS protocol.
type dnsOverHTTPS struct {
	// getDialer either returns an initialized dial handler or creates a new
	// one.
	getDialer DialerInitializer

	// addr is the DNS-over-HTTPS server URL.
	addr *url.URL

	// tlsConf is the configuration of TLS.  Its NextProtos list the HTTP
	// versions this upstream may negotiate.
	tlsConf *tls.Config

	// client is the HTTP client used for the exchanges.  The Client's
	// Transport typically has internal state (cached TCP connections), so
	// Clients should be reused instead of created as needed.  Clients are safe
	// for concurrent use by multiple goroutines.
	client *http.Client

	// clientMu protects client.
	clientMu *sync.Mutex

	// logger is used for exchange logging.  It is never nil.
	logger *slog.Logger

	// quicConf is the QUIC configuration that is used if HTTP/3 is enabled
	// for this upstream.
	quicConf *quic.Config

	// quicConfMu protects quicConf.
	quicConfMu *sync.Mutex

	// transportH2 is an HTTP/2 transport if any.
	transportH2 *http2.Transport

	// addrRedacted is the redacted string representation of addr.  It is saved
	// separately to reduce allocations during logging and error reporting.
	addrRedacted string

	// timeout is used in HTTP client and for H3 probes.
	timeout time.Duration
}

// newDoH returns the DNS-over-HTTPS Upstream.  addr is modified in place: it is
// passed to addPort with defaultPortDoH, and its "h3" scheme, if any, is
// replaced with "https", in which case only HTTP/3 is used.  Otherwise, the
// HTTP versions from opts are used, or [DefaultHTTPVersions] if opts has none.
// The returned error is always nil.
func newDoH(addr *url.URL, opts *Options) (u Upstream, err error) {
	addPort(addr, defaultPortDoH)

	var httpVersions []HTTPVersion
	switch {
	case addr.Scheme == "h3":
		addr.Scheme = "https"
		httpVersions = []HTTPVersion{HTTPVersion3}
	case len(opts.HTTPVersions) == 0:
		httpVersions = DefaultHTTPVersions
	default:
		httpVersions = opts.HTTPVersions
	}

	getDialer := newDialerInitializer(addr, opts)

	quicConf := &quic.Config{
		KeepAlivePeriod: QUICKeepAlivePeriod,
		TokenStore:      newQUICTokenStore(),
		Tracer:          opts.QUICTracer,
	}

	tlsConf := &tls.Config{
		ServerName:   addr.Hostname(),
		RootCAs:      opts.RootCAs,
		CipherSuites: opts.CipherSuites,
		// Use the default capacity for the LRU cache.  It may be useful to
		// store several caches since the user may be routed to different
		// servers in case there's load balancing on the server-side.
		ClientSessionCache: tls.NewLRUClientSessionCache(0),
		MinVersion:         tls.VersionTLS12,
		// #nosec G402 -- TLS certificate verification could be disabled by
		// configuration.
		InsecureSkipVerify:    opts.InsecureSkipVerify,
		VerifyPeerCertificate: opts.VerifyServerCertificate,
		VerifyConnection:      opts.VerifyConnection,
	}
	for _, v := range httpVersions {
		tlsConf.NextProtos = append(tlsConf.NextProtos, string(v))
	}

	ups := &dnsOverHTTPS{
		getDialer:    getDialer,
		addr:         addr,
		quicConf:     quicConf,
		quicConfMu:   &sync.Mutex{},
		tlsConf:      tlsConf,
		clientMu:     &sync.Mutex{},
		logger:       opts.Logger,
		addrRedacted: addr.Redacted(),
		timeout:      opts.Timeout,
	}

	// Release the client's resources if the upstream is garbage-collected
	// without being closed.
	runtime.SetFinalizer(ups, (*dnsOverHTTPS).Close)

	return ups, nil
}

// type check
var _ Upstream = (*dnsOverHTTPS)(nil)

// Address implements the [Upstream] interface for *dnsOverHTTPS.  It returns
// the redacted form of the upstream URL: if the URL contains a userinfo with a
// password, the password is replaced with "xxxxx".
func (p *dnsOverHTTPS) Address() (addr string) {
	addr = p.addrRedacted

	return addr
}

// Exchange implements the [Upstream] interface for *dnsOverHTTPS.  The DNS ID
// of req is zeroed while the request is in flight, and the original ID is
// restored in both req and the returned response.  If the client was cached
// and the exchange fails with an error for which shouldRetry reports true, the
// client is reset and the request is retried at most two times.  If the
// exchange still fails, the client is reset once more so that it isn't reused.
func (p *dnsOverHTTPS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	// In order to maximize HTTP cache friendliness, DoH clients using media
	// formats that include the ID field from the DNS message header, such as
	// "application/dns-message", SHOULD use a DNS ID of 0 in every DNS request.
	//
	// See https://www.rfc-editor.org/rfc/rfc8484.html.
	id := req.Id
	req.Id = 0
	defer func() {
		// Restore the original ID to not break compatibility with proxies.
		req.Id = id
		if resp != nil {
			resp.Id = id
		}
	}()

	// Check if there was already an active client before sending the request.
	// We'll only attempt to re-connect if there was one.
	client, isCached, err := p.getClient()
	if err != nil {
		return nil, fmt.Errorf("failed to init http client: %w", err)
	}

	// Make the first attempt to send the DNS query.
	resp, err = p.exchangeHTTPS(client, req)

	// Make up to 2 attempts to re-create the HTTP client and send the request
	// again.  There are several cases (mostly, with QUIC) where this workaround
	// is necessary to make HTTP client usable.  We need to make 2 attempts in
	// the case when the connection was closed (due to inactivity for example)
	// AND the server refuses to open a 0-RTT connection.
	for i := 0; isCached && p.shouldRetry(err) && i < 2; i++ {
		client, err = p.resetClient(err)
		if err != nil {
			return nil, fmt.Errorf("failed to reset http client: %w", err)
		}

		resp, err = p.exchangeHTTPS(client, req)
	}

	if err != nil {
		// If the request failed anyway, make sure we don't use this client.
		// The reset error, if any, is attached to the exchange error.
		_, resErr := p.resetClient(err)

		return nil, errors.WithDeferred(err, resErr)
	}

	return resp, err
}

// Close implements the [Upstream] interface for *dnsOverHTTPS.  It clears the
// finalizer of p and releases the resources held by the current HTTP client,
// if any.
func (p *dnsOverHTTPS) Close() (err error) {
	p.clientMu.Lock()
	defer p.clientMu.Unlock()

	runtime.SetFinalizer(p, nil)

	if p.client == nil {
		return nil
	}

	return p.closeClient(p.client)
}

// closeClient releases the resources used by client where needed.  HTTP/3
// clients have their transport closed, since keep-alive connections would
// otherwise leak.  For HTTP/1.1 and HTTP/2 clients, the idle connections of
// p.transportH2 are closed.
func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
	switch {
	case isHTTP3(client):
		closer := client.Transport.(io.Closer)

		return closer.Close()
	case p.transportH2 != nil:
		p.transportH2.CloseIdleConnections()
	}

	return nil
}

// exchangeHTTPS wraps exchangeHTTPSClient with logging of the request and its
// result.  Exchanges made by HTTP/3 clients are logged as UDP ones, the rest
// as TCP ones.
func (p *dnsOverHTTPS) exchangeHTTPS(client *http.Client, req *dns.Msg) (resp *dns.Msg, err error) {
	proto := networkUDP
	if !isHTTP3(client) {
		proto = networkTCP
	}

	logBegin(p.logger, p.addrRedacted, proto, req)
	defer func() { logFinish(p.logger, p.addrRedacted, proto, err) }()

	return p.exchangeHTTPSClient(client, req)
}

// exchangeHTTPSClient sends the DNS query req to the DoH resolver using client
// and returns the unpacked response.  The packed query is sent as the
// base64url-encoded "dns" parameter of a GET request (a 0-RTT GET for HTTP/3
// clients).  RFC 8484 restricts "application/dns-message" payloads to 65535
// bytes, so larger response bodies are rejected without being buffered in
// full.  If the response ID doesn't match req.Id, the response is returned
// together with [dns.ErrId].
func (p *dnsOverHTTPS) exchangeHTTPSClient(
	client *http.Client,
	req *dns.Msg,
) (resp *dns.Msg, err error) {
	buf, err := req.Pack()
	if err != nil {
		return nil, fmt.Errorf("packing message: %w", err)
	}

	// It appears, that GET requests are more memory-efficient with Golang
	// implementation of HTTP/2.
	method := http.MethodGet
	if isHTTP3(client) {
		// If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT.
		method = http3.MethodGet0RTT
	}

	q := url.Values{
		"dns": []string{base64.RawURLEncoding.EncodeToString(buf)},
	}

	u := url.URL{
		Scheme:   p.addr.Scheme,
		User:     p.addr.User,
		Host:     p.addr.Host,
		Path:     p.addr.Path,
		RawQuery: q.Encode(),
	}

	httpReq, err := http.NewRequest(method, u.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("creating http request to %s: %w", p.addrRedacted, err)
	}

	// Prevent the client from sending User-Agent header, see
	// https://github.com/AdguardTeam/dnsproxy/issues/211.
	httpReq.Header.Set(httphdr.UserAgent, "")
	httpReq.Header.Set(httphdr.Accept, "application/dns-message")

	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("requesting %s: %w", p.addrRedacted, err)
	}
	defer slogutil.CloseAndLog(httpReq.Context(), p.logger, httpResp.Body, slog.LevelDebug)

	// Read at most one byte more than the largest valid DNS message.  This
	// stops a misbehaving server from making us buffer an arbitrarily large
	// body, while an oversized one can still be detected below.
	body, err := io.ReadAll(io.LimitReader(httpResp.Body, dns.MaxMsgSize+1))
	if err != nil {
		return nil, fmt.Errorf("reading %s: %w", p.addrRedacted, err)
	}

	if httpResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf(
			"expected status %d, got %d from %s",
			http.StatusOK,
			httpResp.StatusCode,
			p.addrRedacted,
		)
	}

	if len(body) > dns.MaxMsgSize {
		return nil, fmt.Errorf(
			"response from %s exceeds %d bytes",
			p.addrRedacted,
			dns.MaxMsgSize,
		)
	}

	resp = &dns.Msg{}
	err = resp.Unpack(body)
	if err != nil {
		return nil, fmt.Errorf(
			"unpacking response from %s: body is %s: %w",
			p.addrRedacted,
			body,
			err,
		)
	}

	if resp.Id != req.Id {
		err = dns.ErrId
	}

	return resp, err
}

// shouldRetry reports whether err is worth re-creating the HTTP client and
// re-sending the request for.  That is the case for network timeouts and for
// errors accepted by isQUICRetryError.  A nil err never triggers a retry.
func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) {
	if err == nil {
		return false
	}

	// Forcibly re-creating the client on a timeout is an attempt to fix the
	// DoH client stalling after a network change.
	//
	// See https://github.com/AdguardTeam/AdGuardHome/issues/3217.
	var netErr net.Error
	isTimeout := errors.As(err, &netErr) && netErr.Timeout()

	return isTimeout || isQUICRetryError(err)
}

// resetClient closes the current *http.Client of this upstream, if any, and
// replaces it with a freshly created one.  resetErr is the error that caused
// the reset.  When it is [quic.Err0RTTRejected], the QUIC config, including
// its token store, is re-created as well.  A failure to close the old client
// is only logged.
func (p *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) {
	p.clientMu.Lock()
	defer p.clientMu.Unlock()

	// The token store is only worth dropping when 0-RTT was rejected.
	if errors.Is(resetErr, quic.Err0RTTRejected) {
		p.resetQUICConfig()
	}

	if old := p.client; old != nil {
		if closeErr := p.closeClient(old); closeErr != nil {
			p.logger.Warn("failed to close the old http client", slogutil.KeyError, closeErr)
		}
	}

	p.logger.Debug("recreating the http client", slogutil.KeyError, resetErr)
	client, err = p.createClient()
	p.client = client

	return client, err
}

// getQUICConfig returns the current QUIC config while holding quicConfMu.  The
// returned pointer is shared, so callers must not modify the config.
func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) {
	p.quicConfMu.Lock()
	c = p.quicConf
	p.quicConfMu.Unlock()

	return c
}

// resetQUICConfig replaces the QUIC config with a copy that has a new token
// store, so that no invalid tokens are used for 0-RTT.  The config is cloned
// rather than modified in place, because getQUICConfig hands out the current
// pointer to callers.
func (p *dnsOverHTTPS) resetQUICConfig() {
	p.quicConfMu.Lock()
	defer p.quicConfMu.Unlock()

	fresh := p.quicConf.Clone()
	fresh.TokenStore = newQUICTokenStore()
	p.quicConf = fresh
}

// getClient returns the HTTP client of this DoH resolver and creates it,
// together with its transport, if there is none yet.  isCached is true when an
// already existing client is returned.  If p.timeout is set and waiting for
// the lock took longer than that, an error is returned instead of creating a
// client.
func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) {
	waitStart := time.Now()

	p.clientMu.Lock()
	defer p.clientMu.Unlock()

	if p.client != nil {
		return p.client, true, nil
	}

	// Waiting for the lock alone may exceed the timeout.  This happens quite
	// often on mobile devices.
	if waited := time.Since(waitStart); p.timeout > 0 && waited > p.timeout {
		return nil, false, fmt.Errorf("timeout exceeded: %s", waited)
	}

	p.logger.Debug("creating a new http client")
	c, err = p.createClient()
	p.client = c

	return c, false, err
}

// createClient builds a new *http.Client for this upstream, stores it in
// p.client, and returns it.  Whether HTTP/3 or HTTP/1.1 and HTTP/2 is used is
// decided by createTransport, which may open a probe QUIC connection to check
// whether HTTP/3 is usable.
func (p *dnsOverHTTPS) createClient() (*http.Client, error) {
	rt, err := p.createTransport()
	if err != nil {
		return nil, fmt.Errorf("initializing http transport: %w", err)
	}

	// TODO(ameshkov):  p.timeout may appear zero that will disable the
	// timeout for client, consider using the default.
	p.client = &http.Client{
		Transport: rt,
		Timeout:   p.timeout,
		Jar:       nil,
	}

	return p.client, nil
}

// createTransport builds the HTTP transport for this DoH resolver.  All of its
// connections go to the address obtained from the bootstrap resolver.  An
// HTTP/3 transport is returned when createTransportH3 succeeds, i.e. HTTP/3 is
// enabled and the QUIC probe wins.  Otherwise, an HTTP/1.1 and HTTP/2
// transport is returned, provided that the upstream allows those versions.
func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
	dialContext, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("bootstrapping %s: %w", p.addrRedacted, err)
	}

	tlsConf := p.tlsConf.Clone()

	// Prefer HTTP/3 if the probe QUIC connection turns out to be usable.
	h3RoundTripper, h3Err := p.createTransportH3(tlsConf, dialContext)
	if h3Err == nil {
		p.logger.Debug("using http/3 for this upstream, quic was faster")

		return h3RoundTripper, nil
	}

	p.logger.Debug("got error, switching to http/2 for this upstream", slogutil.KeyError, h3Err)

	if !p.supportsHTTP() {
		return nil, errors.Error("HTTP1/1 and HTTP2 are not supported by this upstream")
	}

	h1h2 := &http.Transport{
		// The dialer is a custom one, so HTTP/2 has to be requested
		// explicitly.  Otherwise, it would only be used when negotiated on the
		// TLS level.
		ForceAttemptHTTP2:  true,
		DialContext:        dialContext,
		TLSClientConfig:    tlsConf,
		DisableCompression: true,
		IdleConnTimeout:    transportDefaultIdleConnTimeout,
		MaxConnsPerHost:    dohMaxConnsPerHost,
		MaxIdleConns:       dohMaxIdleConns,
	}

	// Configure HTTP/2 on the transport explicitly.
	//
	// See https://github.com/AdguardTeam/dnsproxy/issues/11.
	p.transportH2, err = http2.ConfigureTransports(h1h2)
	if err != nil {
		return nil, err
	}

	// Make idle HTTP/2 connections send pings.
	p.transportH2.ReadIdleTimeout = transportDefaultReadIdleTimeout

	return h1h2, nil
}

// http3Transport is a wrapper over [*http3.Transport] that tries to optimize
// its behavior.  The main thing that it does is trying to force use a single
// connection to a host instead of creating a new one all the time.  It also
// helps mitigate race issues with quic-go.
type http3Transport struct {
	// baseTransport is the wrapped HTTP/3 transport that performs the actual
	// round trips.
	baseTransport *http3.Transport

	// closed is set by Close, after which RoundTrip returns [net.ErrClosed].
	// mu protects closed.  RoundTrip holds mu for reading during the whole
	// request, so Close waits for in-flight requests to finish.
	closed bool
	mu     sync.RWMutex
}

// type check
var _ http.RoundTripper = (*http3Transport)(nil)

// RoundTrip implements the [http.RoundTripper] interface for *http3Transport.
// It first tries to send req over a cached connection and only lets the base
// transport dial a new one when there is no cached connection.  After Close,
// it returns [net.ErrClosed].
func (h *http3Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if h.closed {
		return nil, net.ErrClosed
	}

	cachedOnly := http3.RoundTripOpt{OnlyCachedConn: true}
	resp, err = h.baseTransport.RoundTripOpt(req, cachedOnly)
	if !errors.Is(err, http3.ErrNoCachedConn) {
		return resp, err
	}

	// Nothing to reuse, so establish a new connection.
	return h.baseTransport.RoundTrip(req)
}

// type check
var _ io.Closer = (*http3Transport)(nil)

// Close implements the [io.Closer] interface for *http3Transport.  It marks
// the transport as closed, so that later RoundTrip calls fail, and closes the
// base transport.
func (h *http3Transport) Close() (err error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.closed = true
	err = h.baseTransport.Close()

	return err
}

// createTransportH3 tries to create an HTTP/3 transport for this upstream.  It
// fails when HTTP/3 isn't enabled or when probeH3 decides against QUIC, which
// lets the caller fall back to HTTP/1.1 and HTTP/2.  On success, it returns an
// *http3Transport wrapping a [*http3.Transport] that always dials the address
// found by probeH3.
func (p *dnsOverHTTPS) createTransportH3(
	tlsConfig *tls.Config,
	dialContext bootstrap.DialHandler,
) (roundTripper http.RoundTripper, err error) {
	if !p.supportsH3() {
		return nil, errors.Error("HTTP3 support is not enabled")
	}

	addr, err := p.probeH3(tlsConfig, dialContext)
	if err != nil {
		return nil, err
	}

	// dialBootstrapped ignores the address requested by the transport and
	// always connects to the one we got from the bootstrapper.
	dialBootstrapped := func(
		ctx context.Context,
		_ string,
		tlsCfg *tls.Config,
		cfg *quic.Config,
	) (quic.EarlyConnection, error) {
		return quic.DialAddrEarly(ctx, addr, tlsCfg, cfg)
	}

	base := &http3.Transport{
		TLSClientConfig:    tlsConfig,
		QUICConfig:         p.getQUICConfig(),
		Dial:               dialBootstrapped,
		DisableCompression: true,
	}

	return &http3Transport{baseTransport: base}, nil
}

// probeH3 checks whether QUIC should be preferred to TLS for this upstream.
// On success, it returns the address to use for QUIC connections.  When the
// upstream supports only HTTP/3, no race is run.  Otherwise, QUIC and TLS
// connections are raced, and the first probe to report decides the result.
func (p *dnsOverHTTPS) probeH3(
	tlsConfig *tls.Config,
	dialContext bootstrap.DialHandler,
) (addr string, err error) {
	// The bootstrapped address is used rather than the one passed to the
	// dialer.  Dialing UDP doesn't create an actual connection, but it helps
	// find out which IP is actually reachable when there are both IPv4 and
	// IPv6 addresses.
	rawConn, err := dialContext(context.Background(), "udp", "")
	if err != nil {
		return "", fmt.Errorf("failed to dial: %w", err)
	}

	// The connection itself is never used.
	_ = rawConn.Close()

	udpConn, ok := rawConn.(*net.UDPConn)
	if !ok {
		return "", fmt.Errorf("not a UDP connection to %s", p.addrRedacted)
	}

	addr = udpConn.RemoteAddr().String()

	// There's nothing to race against if the upstream only supports HTTP/3.
	if p.supportsH3() && !p.supportsHTTP() {
		return addr, nil
	}

	// Probes get a clone of the config without a session cache.
	// Surprisingly, sharing the cache with them invalidates the existing
	// cache.
	//
	// TODO(ameshkov): figure out why the sessions cache invalidates here.
	probeTLSCfg := tlsConfig.Clone()
	probeTLSCfg.ClientSessionCache = nil

	// Hide the probes from the callbacks passed in the bootstrap options to
	// avoid side effects.
	//
	// TODO(ameshkov): consider exposing, somehow mark that this is a probe.
	probeTLSCfg.VerifyPeerCertificate = nil
	probeTLSCfg.VerifyConnection = nil

	// Each probe sends exactly one result, and the channels are buffered, so
	// the probe that loses the race doesn't block forever.
	chQUIC := make(chan error, 1)
	chTLS := make(chan error, 1)
	go p.probeQUIC(addr, probeTLSCfg, chQUIC)
	go p.probeTLS(dialContext, probeTLSCfg, chTLS)

	select {
	case quicErr := <-chQUIC:
		if quicErr != nil {
			return "", quicErr
		}

		// QUIC won the race.
		return addr, nil
	case tlsErr := <-chTLS:
		if tlsErr == nil {
			return "", errors.Error("TLS was faster than QUIC, prefer it")
		}

		// TLS failed before QUIC reported, so go with QUIC.
		p.logger.Debug("probing tls", slogutil.KeyError, tlsErr)

		return addr, nil
	}
}

// probeQUIC tries to establish a QUIC connection to addr.  It sends nil to ch
// on success and the dialing error otherwise.  The dial is limited by
// p.timeout, or by dialTimeout when p.timeout is zero.  It runs concurrently
// with probeTLS to see which protocol is faster.
func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) {
	start := time.Now()

	limit := p.timeout
	if limit == 0 {
		limit = dialTimeout
	}

	ctx, cancel := context.WithTimeout(context.Background(), limit)
	defer cancel()

	conn, err := quic.DialAddrEarly(ctx, addr, tlsConfig, p.getQUICConfig())
	if err != nil {
		ch <- fmt.Errorf("opening quic connection to %s: %w", p.addrRedacted, err)

		return
	}

	// The connection was only needed for the handshake.  There's nothing
	// useful to do with an error from closing it.
	_ = conn.CloseWithError(QUICCodeNoError, "")

	ch <- nil

	p.logger.Debug("quic connection established", "elapsed", time.Since(start))
}

// probeTLS tries to establish a TLS connection using dialContext.  It sends
// nil to ch on success and the dialing error otherwise.  It runs concurrently
// with probeQUIC to see which protocol is faster.
func (p *dnsOverHTTPS) probeTLS(dialContext bootstrap.DialHandler, tlsConfig *tls.Config, ch chan error) {
	start := time.Now()

	conn, err := tlsDial(dialContext, tlsConfig)
	if err != nil {
		ch <- fmt.Errorf("opening TLS connection: %w", err)

		return
	}

	// The connection was only needed for the handshake.  There's nothing
	// useful to do with an error from closing it.
	_ = conn.Close()

	ch <- nil

	p.logger.Debug("tls connection established", "elapsed", time.Since(start))
}

// supportsH3 reports whether HTTP/3 is listed among the ALPN protocols of this
// upstream's TLS config.
func (p *dnsOverHTTPS) supportsH3() (ok bool) {
	protos := p.tlsConf.NextProtos
	want := string(HTTPVersion3)
	for i := range protos {
		if protos[i] == want {
			return true
		}
	}

	return false
}

// supportsHTTP reports whether HTTP/1.1 or HTTP/2 is listed among the ALPN
// protocols of this upstream's TLS config.
func (p *dnsOverHTTPS) supportsHTTP() (ok bool) {
	for _, proto := range p.tlsConf.NextProtos {
		switch proto {
		case string(HTTPVersion11), string(HTTPVersion2):
			return true
		}
	}

	return false
}

// isHTTP3 reports whether client uses an *http3Transport, i.e. whether it's an
// HTTP/3 client.
func isHTTP3(client *http.Client) (ok bool) {
	switch client.Transport.(type) {
	case *http3Transport:
		return true
	default:
		return false
	}
}
07070100000097000081A4000000000000000000000001679A649F0000375E000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/upstream/doh_internal_test.gopackage upstream

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"sync/atomic"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"github.com/stretchr/testify/require"
)

// TestUpstreamDoH checks that the DoH upstream negotiates the expected HTTP
// version for various server and client settings, and that the TLS session is
// resumed once the HTTP client is re-created.
func TestUpstreamDoH(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name             string
		expectedProtocol HTTPVersion
		httpVersions     []HTTPVersion
		delayHandshakeH3 time.Duration
		delayHandshakeH2 time.Duration
		http3Enabled     bool
	}{{
		name:             "http1.1_h2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion11, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "fallback_to_http2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "http3",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3},
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http3_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH2: time.Second,
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http2_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH3: time.Second,
		expectedProtocol: HTTPVersion2,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			srv := startDoHServer(t, testDoHServerOptions{
				http3Enabled:     tc.http3Enabled,
				delayHandshakeH2: tc.delayHandshakeH2,
				delayHandshakeH3: tc.delayHandshakeH3,
			})

			// Create a DNS-over-HTTPS upstream.
			address := fmt.Sprintf("https://%s/dns-query", srv.addr)

			// lastState is the state of the most recently verified connection.
			// VerifyConnection also fails the handshake if the negotiated
			// protocol isn't the expected one.
			var lastState tls.ConnectionState
			opts := &Options{
				Logger:             slogutil.NewDiscardLogger(),
				InsecureSkipVerify: true,
				HTTPVersions:       tc.httpVersions,
				VerifyConnection: func(state tls.ConnectionState) (err error) {
					if state.NegotiatedProtocol != string(tc.expectedProtocol) {
						return fmt.Errorf(
							"expected %s, got %s",
							tc.expectedProtocol,
							state.NegotiatedProtocol,
						)
					}
					lastState = state
					return nil
				},
			}
			u, err := AddressToUpstream(address, opts)
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			// Test that it responds properly.
			for range 10 {
				checkUpstream(t, u, address)
			}

			doh := u.(*dnsOverHTTPS)

			// Trigger re-connection.  Close the current client under the lock,
			// the same way the upstream does, so that the connections of the
			// dropped client aren't left open.
			func() {
				doh.clientMu.Lock()
				defer doh.clientMu.Unlock()

				require.NoError(t, doh.closeClient(doh.client))
				doh.client = nil
			}()

			// Force it to establish the connection again.
			checkUpstream(t, u, address)

			// Check that TLS session was resumed properly.
			require.True(t, lastState.DidResume)
		})
	}
}

// TestUpstreamDoH_raceReconnect is meant to be run with -race.  Its handler
// delays every tenth request for twice the upstream timeout, which triggers
// HTTP client re-connections while other requests are in flight.
func TestUpstreamDoH_raceReconnect(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name             string
		expectedProtocol HTTPVersion
		httpVersions     []HTTPVersion
		delayHandshakeH3 time.Duration
		delayHandshakeH2 time.Duration
		http3Enabled     bool
	}{{
		name:             "http1.1_h2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion11, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "fallback_to_http2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "http3",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3},
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http3_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH2: time.Second,
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http2_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH3: time.Second,
		expectedProtocol: HTTPVersion2,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			const timeout = 100 * time.Millisecond

			var reqCount atomic.Int32
			baseHandler := createDoHHandlerFunc()
			mux := http.NewServeMux()
			mux.HandleFunc("/dns-query", func(w http.ResponseWriter, r *http.Request) {
				// Stall every tenth request for longer than the upstream
				// timeout to force the client to re-connect.
				if reqCount.Add(1)%10 == 0 {
					time.Sleep(2 * timeout)
				}

				baseHandler(w, r)
			})

			srv := startDoHServer(t, testDoHServerOptions{
				handler:          mux,
				http3Enabled:     tc.http3Enabled,
				delayHandshakeH2: tc.delayHandshakeH2,
				delayHandshakeH3: tc.delayHandshakeH3,
			})

			upsURL := fmt.Sprintf("https://%s/dns-query", srv.addr)
			ups, err := AddressToUpstream(upsURL, &Options{
				Logger:             slogutil.NewDiscardLogger(),
				InsecureSkipVerify: true,
				HTTPVersions:       tc.httpVersions,
				Timeout:            timeout,
			})
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, ups.Close)

			checkRaceCondition(ups)
		})
	}
}

// TestUpstreamDoH_serverRestart checks that a DoH upstream keeps working after
// the server it is connected to is stopped and started again on the same
// port.  The subtests are not parallel: each one depends on the state left by
// the previous one and must run in order.
func TestUpstreamDoH_serverRestart(t *testing.T) {
	testCases := []struct {
		name         string
		httpVersions []HTTPVersion
	}{{
		name:         "http2",
		httpVersions: []HTTPVersion{HTTPVersion11, HTTPVersion2},
	}, {
		name:         "http3",
		httpVersions: []HTTPVersion{HTTPVersion3},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// These are shared between the subtests below.
			var addr netip.AddrPort
			var upsAddr string
			var u Upstream

			// Start the first server on a random port and make the upstream
			// talk to it.  The server is shut down by the cleanup of this
			// subtest, which startDoHServer registers.
			t.Run("first_try", func(t *testing.T) {
				srv := startDoHServer(t, testDoHServerOptions{
					http3Enabled: true,
				})

				addr = netip.MustParseAddrPort(srv.addr)
				upsAddr = (&url.URL{
					Scheme: "https",
					Host:   addr.String(),
					Path:   "dns-query",
				}).String()

				var err error
				u, err = AddressToUpstream(upsAddr, &Options{
					Logger:             slogutil.NewDiscardLogger(),
					InsecureSkipVerify: true,
					HTTPVersions:       tc.httpVersions,
					Timeout:            100 * time.Millisecond,
				})
				require.NoError(t, err)

				checkUpstream(t, u, upsAddr)
			})
			// Stop early, since u may not have been created.
			require.False(t, t.Failed())
			testutil.CleanupAndRequireSuccess(t, u.Close)

			// Restart the server on the same port and check that the same
			// upstream reaches it.
			t.Run("second_try", func(t *testing.T) {
				_ = startDoHServer(t, testDoHServerOptions{
					http3Enabled: true,
					port:         int(addr.Port()),
				})

				checkUpstream(t, u, upsAddr)
			})
			require.False(t, t.Failed())

			// The server from second_try is already shut down by that
			// subtest's cleanup, so the first exchange must fail.  After the
			// server is started again, the upstream must recover.
			t.Run("retry", func(t *testing.T) {
				_, err := u.Exchange(createTestMessage())
				require.Error(t, err)

				_ = startDoHServer(t, testDoHServerOptions{
					http3Enabled: true,
					port:         int(addr.Port()),
				})

				checkUpstream(t, u, upsAddr)
			})
		})
	}
}

// TestUpstreamDoH_0RTT checks that, after the HTTP/3 client is closed and
// dropped, the next exchange reconnects using 0-RTT, while the very first
// connection doesn't.
func TestUpstreamDoH_0RTT(t *testing.T) {
	t.Parallel()

	srv := startDoHServer(t, testDoHServerOptions{
		http3Enabled: true,
	})

	// The tracer records every QUIC connection the upstream makes.
	tracer := &quicTracer{}
	upsURL := fmt.Sprintf("h3://%s/dns-query", srv.addr)
	ups, err := AddressToUpstream(upsURL, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
		QUICTracer:         tracer.TracerForConnection,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, ups.Close)

	doh := ups.(*dnsOverHTTPS)
	req := createTestMessage()

	exchangeAndCheck := func() {
		resp, exchErr := doh.Exchange(req)
		require.NoError(t, exchErr)
		requireResponse(t, req, resp)
	}

	// The first exchange opens the first connection to the DoH3 server.
	exchangeAndCheck()

	// Close and drop the client so that the next exchange has to reconnect.
	func() {
		doh.clientMu.Lock()
		defer doh.clientMu.Unlock()

		require.NoError(t, doh.closeClient(doh.client))
		doh.client = nil
	}()

	// The second exchange opens the second connection.
	exchangeAndCheck()

	conns := tracer.getConnectionsInfo()
	require.Len(t, conns, 2)

	// Only the reconnection is expected to use 0-RTT.
	require.False(t, conns[0].is0RTT())
	require.True(t, conns[1].is0RTT())
}

// testDoHServerOptions allows customizing testDoHServer behavior.  The zero
// value starts only an HTTP/1.1 and HTTP/2 server with the default handler on
// a random port and without handshake delays.
type testDoHServerOptions struct {
	// handler is an HTTP handler that should be used by the server.  If it's
	// nil, the handler returned by createDoHHandler is used.
	handler http.Handler
	// delayHandshakeH2 is a delay that should be added to the handshake of the
	// HTTP/2 server.  It's applied in the GetConfigForClient callback of the
	// server's TLS config.
	delayHandshakeH2 time.Duration
	// delayHandshakeH3 is a delay that should be added to the handshake of the
	// HTTP/3 server.  It's applied in the GetConfigForClient callback of the
	// server's TLS config.
	delayHandshakeH3 time.Duration
	// port is the TCP port that the server should listen to.  If it's 0, a
	// random port is used.  The HTTP/3 server, if enabled, listens to the
	// same port number over UDP.
	port int
	// http3Enabled is a flag that indicates whether the server should start an
	// HTTP/3 server.
	http3Enabled bool
}

// testDoHServer is an instance of a test DNS-over-HTTPS server created by
// startDoHServer.
type testDoHServer struct {
	// tlsConfig is the base TLS configuration of this server.  The HTTP/2 and
	// HTTP/3 servers use clones of it with their own ALPN protocols.
	tlsConfig *tls.Config

	// rootCAs is the pool with root certificates used by the test server.
	rootCAs *x509.CertPool

	// server is an HTTP/1.1 and HTTP/2 server.  It's always running.
	server *http.Server

	// serverH3 is an HTTP/3 server.  It's nil unless HTTP/3 is enabled.
	serverH3 *http3.Server

	// listenerH3 is the QUIC listener that's used to serve HTTP/3.  It's nil
	// unless HTTP/3 is enabled.
	listenerH3 *quic.EarlyListener

	// addr is the TCP address that this server listens to.  When HTTP/3 is
	// enabled, it's served over UDP on the same port.
	addr string
}

// Shutdown stops the DoH server: the HTTP/1.1 and HTTP/2 server and, when
// HTTP/3 is enabled, the HTTP/3 server and its QUIC listener.  Errors are
// ignored.
func (s *testDoHServer) Shutdown() {
	if srv := s.server; srv != nil {
		_ = srv.Shutdown(context.Background())
	}

	if s.serverH3 == nil {
		return
	}

	_ = s.serverH3.Close()
	_ = s.listenerH3.Close()
}

// startDoHServer starts a new DNS-over-HTTPS server with specified options.  It
// returns a started server instance with addr set.  Note that it adds its own
// shutdown to cleanup of t.
//
// An HTTP/1.1 and HTTP/2 server is always started on 127.0.0.1.  When
// opts.http3Enabled is true, an HTTP/3 server is also started on the same port
// number over UDP.
func startDoHServer(
	t *testing.T,
	opts testDoHServerOptions,
) (s *testDoHServer) {
	tlsConfig, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	handler := opts.handler
	if handler == nil {
		handler = createDoHHandler()
	}

	// Step one is to create a regular HTTP server, we'll always have it
	// running.
	server := &http.Server{
		Handler:     handler,
		ReadTimeout: time.Second,
	}

	// Listen TCP first.  With opts.port set to 0, the system picks a free
	// port.
	listenAddr := fmt.Sprintf("127.0.0.1:%d", opts.port)
	tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
	require.NoError(t, err)

	tcpListen, err := net.ListenTCP("tcp", tcpAddr)
	require.NoError(t, err)

	tlsConfigH2 := tlsConfig.Clone()
	tlsConfigH2.NextProtos = []string{string(HTTPVersion2), string(HTTPVersion11)}
	// This callback only exists to delay the handshake.  Returning a nil
	// config keeps using tlsConfigH2 itself.
	tlsConfigH2.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
		if opts.delayHandshakeH2 > 0 {
			time.Sleep(opts.delayHandshakeH2)
		}
		return nil, nil
	}
	tlsListen := tls.NewListener(tcpListen, tlsConfigH2)

	// Run the H1/H2 server.
	go func() {
		// TODO(ameshkov): check the error here.
		_ = server.Serve(tlsListen)
	}()

	// Get the real address that the listener now listens to.
	tcpAddr = tcpListen.Addr().(*net.TCPAddr)

	var serverH3 *http3.Server
	var listenerH3 *quic.EarlyListener

	if opts.http3Enabled {
		tlsConfigH3 := tlsConfig.Clone()
		tlsConfigH3.NextProtos = []string{string(HTTPVersion3)}
		// As above, the callback only delays the handshake.
		tlsConfigH3.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
			if opts.delayHandshakeH3 > 0 {
				time.Sleep(opts.delayHandshakeH3)
			}
			return nil, nil
		}

		serverH3 = &http3.Server{
			Handler: handler,
		}

		// Listen UDP for the H3 server. Reuse the same port as was used for the
		// TCP listener.
		var udpAddr *net.UDPAddr
		udpAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port))
		require.NoError(t, err)

		var conn net.PacketConn
		conn, err = net.ListenUDP("udp", udpAddr)
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, conn.Close)

		// Never ask clients to validate their source address first.
		transport := &quic.Transport{
			Conn:                conn,
			VerifySourceAddress: func(net.Addr) bool { return false },
		}

		// QUIC configuration with the 0-RTT support enabled by default.
		listenerH3, err = transport.ListenEarly(tlsConfigH3, &quic.Config{
			Allow0RTT: true,
		})
		require.NoError(t, err)
		// Cleanups run in reverse order, so the transport is closed before
		// the UDP connection it uses.
		testutil.CleanupAndRequireSuccess(t, transport.Close)

		// Run the H3 server.
		go func() {
			// TODO(ameshkov): check the error here.
			_ = serverH3.ServeListener(listenerH3)
		}()
	}

	s = &testDoHServer{
		tlsConfig:  tlsConfig,
		rootCAs:    rootCAs,
		server:     server,
		serverH3:   serverH3,
		listenerH3: listenerH3,
		// Save the address that the server listens to.
		addr: tcpAddr.String(),
	}
	// Registered last, so it runs before the cleanups registered above.
	t.Cleanup(s.Shutdown)

	return s
}

// createDoHHandlerFunc returns an [http.HandlerFunc] that decodes the DNS
// message from the base64url-encoded "dns" query parameter and answers it with
// respondToTestMessage.  The packed answer is written with the
// "application/dns-message" content type.  Decoding, unpacking, and packing
// errors are reported as 500 Internal Server Error.  A failure to write the
// response causes a panic.
func createDoHHandlerFunc() (f http.HandlerFunc) {
	return func(w http.ResponseWriter, r *http.Request) {
		replyErr := func(err error) {
			http.Error(
				w,
				fmt.Sprintf("internal error: %s", err),
				http.StatusInternalServerError,
			)
		}

		wire, err := base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
		if err != nil {
			replyErr(err)

			return
		}

		req := &dns.Msg{}
		if err = req.Unpack(wire); err != nil {
			replyErr(err)

			return
		}

		answer := respondToTestMessage(req)
		if wire, err = answer.Pack(); err != nil {
			replyErr(err)

			return
		}

		w.Header().Set("Content-Type", "application/dns-message")

		if _, err = w.Write(wire); err != nil {
			panic(fmt.Errorf("unexpected error on writing response: %w", err))
		}
	}
}

// createDoHHandler returns an [http.Handler] that serves the handler from
// createDoHHandlerFunc on the "/dns-query" path.
func createDoHHandler() (h http.Handler) {
	mux := http.NewServeMux()
	mux.Handle("/dns-query", createDoHHandlerFunc())
	h = mux

	return h
}
07070100000098000081A4000000000000000000000001679A649F00003DB0000000000000000000000000000000000000002000000000dnsproxy-0.75.0/upstream/doq.gopackage upstream

import (
	"context"
	"crypto/tls"
	"fmt"
	"log/slog"
	"net"
	"net/url"
	"os"
	"runtime"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
)

const (
	// QUICCodeNoError is used when the connection or stream needs to be closed,
	// but there is no error to signal.
	QUICCodeNoError = quic.ApplicationErrorCode(0)

	// QUICCodeInternalError signals that the DoQ implementation encountered
	// an internal error and is incapable of pursuing the transaction or the
	// connection.
	QUICCodeInternalError = quic.ApplicationErrorCode(1)

	// QUICKeepAlivePeriod is the value that we pass to *quic.Config and that
	// controls the period with which keep-alive frames are sent on the
	// connection.  We set it to 20s, as it would be in quic-go@v0.27.1 with the
	// KeepAlive field set to true.  This value is specified in
	// https://pkg.go.dev/github.com/quic-go/quic-go/internal/protocol#MaxKeepAliveInterval.
	//
	// TODO(ameshkov):  Consider making it configurable.
	QUICKeepAlivePeriod = time.Second * 20

	// NextProtoDQ is the ALPN token for DoQ.  During the connection
	// establishment, DNS/QUIC support is indicated by selecting the ALPN token
	// "doq" in the crypto handshake.
	//
	// See https://datatracker.ietf.org/doc/rfc9250.
	NextProtoDQ = "doq"
)

// compatProtoDQ is a list of ALPN tokens used by a QUIC connection.  The first
// one is NextProtoDQ, the token from RFC 9250.  The rest are tokens used by
// previous drafts, kept for compatibility.
var compatProtoDQ = []string{NextProtoDQ, "doq-i00", "dq", "doq-i02"}

// dnsOverQUIC implements the [Upstream] interface for the DNS-over-QUIC
// protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html).
type dnsOverQUIC struct {
	// getDialer either returns an initialized dial handler or creates a new
	// one.
	getDialer DialerInitializer

	// addr is the DNS-over-QUIC server URL.
	addr *url.URL

	// tlsConf is the TLS configuration used for connections to the upstream.
	tlsConf *tls.Config

	// quicConfig is the QUIC configuration that is used for establishing
	// connections to the upstream.  This configuration includes the TokenStore
	// that needs to be stored for the lifetime of dnsOverQUIC since we can
	// re-create the connection.
	quicConfig *quic.Config

	// conn is the current active QUIC connection.  It can be closed and
	// re-opened when needed.
	conn quic.Connection

	// bytesPool is a *sync.Pool we use to store byte buffers in.  These byte
	// buffers are used to read responses from the upstream.
	bytesPool *sync.Pool

	// quicConfigMu protects quicConfig.
	quicConfigMu *sync.Mutex

	// connMu protects conn.
	connMu *sync.Mutex

	// bytesPoolMu protects bytesPool.
	bytesPoolMu *sync.Mutex

	// logger is used for exchange logging.  It is never nil.
	logger *slog.Logger

	// timeout is the timeout for the upstream connection.
	timeout time.Duration
}

// newDoQ creates a DNS-over-QUIC upstream for addr.  The default DoQ port is
// added to addr when it has none.  A finalizer closes the upstream if Close is
// never called.
func newDoQ(addr *url.URL, opts *Options) (u Upstream, err error) {
	addPort(addr, defaultPortDoQ)

	getDialer := newDialerInitializer(addr, opts)

	quicConf := &quic.Config{
		KeepAlivePeriod: QUICKeepAlivePeriod,
		TokenStore:      newQUICTokenStore(),
		Tracer:          opts.QUICTracer,
	}

	tlsConf := &tls.Config{
		ServerName:   addr.Hostname(),
		RootCAs:      opts.RootCAs,
		CipherSuites: opts.CipherSuites,
		// An LRU cache of the default capacity is used, since keeping several
		// sessions helps when load balancing on the server side routes the
		// client to different servers.
		ClientSessionCache: tls.NewLRUClientSessionCache(0),
		MinVersion:         tls.VersionTLS12,
		// #nosec G402 -- TLS certificate verification could be disabled by
		// configuration.
		InsecureSkipVerify:    opts.InsecureSkipVerify,
		VerifyPeerCertificate: opts.VerifyServerCertificate,
		VerifyConnection:      opts.VerifyConnection,
		NextProtos:            compatProtoDQ,
	}

	doq := &dnsOverQUIC{
		getDialer:    getDialer,
		addr:         addr,
		tlsConf:      tlsConf,
		quicConfig:   quicConf,
		quicConfigMu: &sync.Mutex{},
		connMu:       &sync.Mutex{},
		bytesPoolMu:  &sync.Mutex{},
		logger:       opts.Logger,
		timeout:      opts.Timeout,
	}

	runtime.SetFinalizer(doq, (*dnsOverQUIC).Close)

	return doq, nil
}

// type check: the build fails if *dnsOverQUIC stops implementing [Upstream].
var _ Upstream = (*dnsOverQUIC)(nil)

// Address implements the [Upstream] interface for *dnsOverQUIC.  It returns
// the upstream URL in its string form.
func (p *dnsOverQUIC) Address() string {
	return p.addr.String()
}

// Exchange implements the [Upstream] interface for *dnsOverQUIC.
//
// RFC 9250, section 4.2.1 requires the DNS Message ID to be 0 over DoQ, since
// streams already correlate queries with responses.  The ID of req is
// therefore zeroed for the duration of the exchange.  It is restored before
// returning, and copied into resp, so that proxies relying on it keep working.
//
// A failure on a cached connection is retried once over a new connection.
// Any connection that fails the exchange is closed with an internal error.
func (p *dnsOverQUIC) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	origID := req.Id
	req.Id = 0
	defer func() {
		req.Id = origID
		if resp != nil {
			resp.Id = origID
		}
	}()

	conn, isCached, err := p.getConnection()
	if err != nil {
		return nil, fmt.Errorf("getting conn: %w", err)
	}

	resp, err = p.exchangeQUIC(req, conn)
	if err != nil && isCached {
		// The cached connection may have been closed by the server or broken
		// by UDP NAT, so drop it and make a second attempt over a fresh one.
		p.logger.Debug("recreating the quic connection and retrying", slogutil.KeyError, err)
		p.closeConnWithError(conn, err)

		conn, _, err = p.getConnection()
		if err != nil {
			return nil, fmt.Errorf("getting new conn: %w", err)
		}

		resp, err = p.exchangeQUIC(req, conn)
	}

	if err != nil {
		// Don't let the failed connection be reused.
		p.closeConnWithError(conn, err)
	}

	return resp, err
}

// Close implements the [Upstream] interface for *dnsOverQUIC.  It removes the
// finalizer and closes the cached QUIC connection, if any, with
// QUICCodeNoError.  The connection is also removed from the cache, so that
// getConnection never hands out a connection closed here.
func (p *dnsOverQUIC) Close() (err error) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	// The upstream is closed explicitly, so the finalizer is not needed.
	runtime.SetFinalizer(p, nil)

	if p.conn != nil {
		err = p.conn.CloseWithError(QUICCodeNoError, "")
		p.conn = nil
	}

	return err
}

// exchangeQUIC sends req over a new stream of conn and returns the response
// read from that stream.
//
// The query is written with its 2-octet length prefix, and then the write side
// of the stream is closed.  When p.timeout is positive, it bounds the whole
// stream exchange.
func (p *dnsOverQUIC) exchangeQUIC(req *dns.Msg, conn quic.Connection) (resp *dns.Msg, err error) {
	address := p.Address()

	logBegin(p.logger, address, networkUDP, req)
	defer func() { logFinish(p.logger, address, networkUDP, err) }()

	packed, err := req.Pack()
	if err != nil {
		return nil, fmt.Errorf("failed to pack DNS message for DoQ: %w", err)
	}

	stream, err := p.openStream(conn)
	if err != nil {
		return nil, fmt.Errorf("opening stream: %w", err)
	}

	if p.timeout > 0 {
		if err = stream.SetDeadline(time.Now().Add(p.timeout)); err != nil {
			return nil, fmt.Errorf("setting deadline: %w", err)
		}
	}

	if _, err = stream.Write(proxyutil.AddPrefix(packed)); err != nil {
		return nil, fmt.Errorf("failed to write to a QUIC stream: %w", err)
	}

	// RFC 9250 requires the client to signal with STREAM FIN that nothing else
	// follows the query.  stream.Close only closes the write direction, so the
	// response can still be read.  A failure here is only logged.
	if closeErr := stream.Close(); closeErr != nil {
		p.logger.Debug("closing quic stream", slogutil.KeyError, closeErr)
	}

	return p.readMsg(stream)
}

// getBytesPool returns the pool of response buffers, creating it on the first
// call.  The pool hands out *[]byte values of dns.MaxMsgSize bytes.
func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
	p.bytesPoolMu.Lock()
	defer p.bytesPoolMu.Unlock()

	if p.bytesPool != nil {
		return p.bytesPool
	}

	p.bytesPool = &sync.Pool{
		New: func() any {
			buf := make([]byte, dns.MaxMsgSize)

			return &buf
		},
	}

	return p.bytesPool
}

// getConnection returns the cached QUIC connection.  If there is none, it
// dials a new one and caches it.  cached reports whether conn came from the
// cache.
func (p *dnsOverQUIC) getConnection() (conn quic.Connection, cached bool, err error) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	if p.conn != nil {
		return p.conn, true, nil
	}

	newConn, err := p.openConnection()
	if err != nil {
		return nil, false, err
	}

	p.conn = newConn

	return newConn, false, nil
}

// getQUICConfig returns the current QUIC config, reading it under
// quicConfigMu.  The returned config is shared, so callers must not modify it.
func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) {
	p.quicConfigMu.Lock()
	c = p.quicConfig
	p.quicConfigMu.Unlock()

	return c
}

// resetQUICConfig replaces the QUIC config with a clone that has a new token
// store.  A new token store may be needed after a failed connection attempt.
// The previous config is left untouched, so values previously returned by
// getQUICConfig are not affected.
func (p *dnsOverQUIC) resetQUICConfig() {
	p.quicConfigMu.Lock()
	defer p.quicConfigMu.Unlock()

	conf := p.quicConfig.Clone()
	conf.TokenStore = newQUICTokenStore()
	p.quicConfig = conf
}

// openStream opens a new stream on conn.  When p.timeout is positive, the wait
// for the stream is bounded by it.
func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) {
	ctx, cancel := p.withDeadline(context.Background())
	defer cancel()

	s, err := conn.OpenStreamSync(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to open a QUIC stream: %w", err)
	}

	return s, nil
}

// openConnection dials a new QUIC connection to the upstream.
//
// The bootstrapped dialer is only used to pick the remote address.  Dialing
// UDP does not create an actual connection, but it shows which of the
// IPv4/IPv6 addresses is reachable.  The resulting socket is closed right away,
// and the QUIC connection is dialed to its remote address.
func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
	dialContext, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("bootstrapping %s: %w", p.addr, err)
	}

	// The address is empty on purpose: the dialer uses the bootstrapped one.
	rawConn, err := dialContext(context.Background(), "udp", "")
	if err != nil {
		return nil, fmt.Errorf("dialing raw connection to %s: %w", p.addr, err)
	}

	// The raw connection is never used for I/O, only for its address.
	if closeErr := rawConn.Close(); closeErr != nil {
		p.logger.Debug("closing raw connection", "addr", p.addr, slogutil.KeyError, closeErr)
	}

	udpConn, ok := rawConn.(*net.UDPConn)
	if !ok {
		return nil, fmt.Errorf("unexpected type %T of connection; should be %T", rawConn, udpConn)
	}

	remote := udpConn.RemoteAddr().String()

	ctx, cancel := p.withDeadline(context.Background())
	defer cancel()

	conn, err = quic.DialAddrEarly(ctx, remote, p.tlsConf.Clone(), p.getQUICConfig())
	if err != nil {
		return nil, fmt.Errorf("dialing quic connection to %s: %w", p.addr, err)
	}

	return conn, nil
}

// closeConnWithError closes conn so that new queries are processed over
// another connection, and removes conn from the cache if it is cached.
//
// A non-nil err is reported to the peer as QUICCodeInternalError.  A nil err
// is reported as QUICCodeNoError.  When err is quic.Err0RTTRejected, the QUIC
// config is also reset to get a new token store.
func (p *dnsOverQUIC) closeConnWithError(conn quic.Connection, err error) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	code := QUICCodeInternalError
	if err == nil {
		code = QUICCodeNoError
	}

	// The server no longer knows the stored tokens, so they must be dropped.
	if errors.Is(err, quic.Err0RTTRejected) {
		p.resetQUICConfig()
	}

	if closeErr := conn.CloseWithError(code, ""); closeErr != nil {
		p.logger.Error("failed to close the conn", slogutil.KeyError, closeErr)
	}

	if conn == p.conn {
		p.conn = nil
	}
}

// readMsg reads a single DNS message from the QUIC stream and unpacks it.
//
// RFC 9250 requires every DNS message on a DoQ stream to be encoded as a
// 2-octet length field followed by the message content.  The prefix is read
// first, and then exactly that many bytes.  io.ReadFull is used for both,
// since a single Read may return only part of the data when the response
// spans several STREAM frames.  Only one message per stream is supported, so
// reading is cancelled afterwards.
func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *dns.Msg, err error) {
	pool := p.getBytesPool()
	bufPtr := pool.Get().(*[]byte)

	defer pool.Put(bufPtr)

	// Stop reading from the stream in every case, including failures.
	defer stream.CancelRead(0)

	var prefix [2]byte
	_, err = io.ReadFull(stream, prefix[:])
	if err != nil {
		return nil, fmt.Errorf("reading response from %s: %w", p.addr, err)
	}

	// The pooled buffers are dns.MaxMsgSize bytes long, which is the largest
	// value a 16-bit length can hold, so this slicing is always in range.  Only
	// the bytes of the current message are used, never stale pool contents.
	respBuf := (*bufPtr)[:binary.BigEndian.Uint16(prefix[:])]
	_, err = io.ReadFull(stream, respBuf)
	if err != nil {
		return nil, fmt.Errorf("reading response from %s: %w", p.addr, err)
	}

	m = new(dns.Msg)
	err = m.Unpack(respBuf)
	if err != nil {
		return nil, fmt.Errorf("unpacking response from %s: %w", p.addr, err)
	}

	return m, nil
}

// newQUICTokenStore returns a new quic.TokenStore.  A token store is required
// to benefit from 0-RTT.  See RFC 9000, section 8.1 on address validation:
// https://datatracker.ietf.org/doc/html/rfc9000#section-8.1.
func newQUICTokenStore() (s quic.TokenStore) {
	// A single origin with up to ten tokens is assumed to be more than enough,
	// since there is one connection per upstream.
	const (
		maxOrigins      = 1
		tokensPerOrigin = 10
	)

	s = quic.NewLRUTokenStore(maxOrigins, tokensPerOrigin)

	return s
}

// isQUICRetryError reports whether err indicates that the QUIC connection
// should be re-created.  The list of such errors comes from quic-go
// behavior; each case below explains why the error qualifies.
//
// TODO(ameshkov): re-test when updating quic-go.
func isQUICRetryError(err error) (ok bool) {
	var (
		appErr       *quic.ApplicationError
		idleErr      *quic.IdleTimeoutError
		resetErr     *quic.StatelessResetError
		transportErr *quic.TransportError
	)

	switch {
	case errors.As(err, &appErr) &&
		(appErr.ErrorCode == 0 ||
			appErr.ErrorCode == quic.ApplicationErrorCode(http3.ErrCodeNoError)):
		// Code 0 is often returned when the server has been restarted and the
		// client keeps using the same connection.  http3.ErrCodeNoError may
		// be used by an HTTP/3 server when closing an idle connection, which
		// the HTTP client doesn't close immediately.
	case errors.As(err, &idleErr):
		// The connection was closed due to being idle.  To reproduce, stop the
		// server, wait for 30 seconds and send another request through the
		// same upstream.
	case errors.As(err, &resetErr):
		// A stateless reset is sent when the server receives a QUIC packet it
		// can't decrypt, for instance right after a reboot.
	case errors.As(err, &transportErr) && transportErr.ErrorCode == quic.NoError:
		// The server may close the connection with NO_ERROR when it considers
		// it is time to, e.g. Google DNS with "Connection max age expired":
		// https://github.com/AdguardTeam/dnsproxy/issues/283
	case errors.Is(err, quic.Err0RTTRejected):
		// The 0-RTT token is unknown to the server, e.g. after a restart that
		// cleared its tokens cache.  This keeps happening until the client's
		// tokens cache is purged.
	case errors.Is(err, os.ErrDeadlineExceeded):
		// A timeout that could happen when the server has been restarted.
	default:
		return false
	}

	return true
}

// withDeadline returns a child of parent that expires after p.timeout.  When
// p.timeout is not positive, it returns parent itself and a no-op cancel
// function.
func (p *dnsOverQUIC) withDeadline(
	parent context.Context,
) (ctx context.Context, cancel context.CancelFunc) {
	if p.timeout <= 0 {
		return parent, func() {}
	}

	return context.WithDeadline(parent, time.Now().Add(p.timeout))
}
07070100000099000081A4000000000000000000000001679A649F00002EC1000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/upstream/doq_internal_test.gopackage upstream

import (
	"context"
	"crypto/tls"
	"encoding/binary"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/netip"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/logging"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestUpstreamDoQ checks four things:
//   - the DoQ upstream answers queries;
//   - it reuses a single connection for them;
//   - it resumes the TLS session after that connection is closed;
//   - a freshly created instance is free of race conditions.
func TestUpstreamDoQ(t *testing.T) {
	tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	srv := startDoQServer(t, tlsConf, 0)

	upsAddr := fmt.Sprintf("quic://%s", srv.addr)

	var lastState tls.ConnectionState
	opts := &Options{
		Logger:  slogutil.NewDiscardLogger(),
		RootCAs: rootCAs,
		VerifyConnection: func(state tls.ConnectionState) error {
			lastState = state

			return nil
		},
	}

	u, err := AddressToUpstream(upsAddr, opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	uq := u.(*dnsOverQUIC)

	// All the queries must go through the very same connection.
	var conn quic.Connection
	for range 10 {
		checkUpstream(t, u, upsAddr)

		if conn != nil {
			require.Equal(t, conn, uq.conn)
		} else {
			conn = uq.conn
		}
	}

	// Close the connection, so that the next query has to establish a new one.
	_ = conn.CloseWithError(quic.ApplicationErrorCode(0), "")

	checkUpstream(t, u, upsAddr)

	// The new connection must have resumed the previous TLS session.
	require.True(t, lastState.DidResume)

	// Check a freshly created upstream, including its initialization, for race
	// conditions.
	u, err = AddressToUpstream(upsAddr, opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	checkRaceCondition(u)
}

// TestUpstream_Exchange_quicServerCloseConn checks that parallel queries
// succeed after the server has closed all the active connections.  See
// https://github.com/AdguardTeam/dnsproxy/issues/389.
func TestUpstream_Exchange_quicServerCloseConn(t *testing.T) {
	// Use the same tlsConf for all servers to preserve the data necessary for
	// 0-RTT connections.
	tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")

	// Run the first server instance.
	srv := startDoQServer(t, tlsConf, 0)

	// Create a DNS-over-QUIC upstream.
	address := fmt.Sprintf("quic://%s", srv.addr)
	u, err := AddressToUpstream(address, &Options{
		Logger:  slogutil.NewDiscardLogger(),
		RootCAs: rootCAs,
	})

	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Test that the upstream works properly.
	checkUpstream(t, u, address)

	// Close all active connections.
	err = srv.closeConns()
	require.NoError(t, err)

	// Now run several queries in parallel to check that the error from the
	// following issue is not happening:
	// https://github.com/AdguardTeam/dnsproxy/issues/389.
	//
	// Run 10 queries in parallel as the initial testing showed that this is
	// enough to trigger the race issue.
	const parallelQueries = 10

	wg := sync.WaitGroup{}
	wg.Add(parallelQueries)

	// The number of goroutines must match the counter of wg, so both are
	// derived from the same constant.
	for range parallelQueries {
		pt := testutil.PanicT{}

		go func(t assert.TestingT) {
			defer wg.Done()

			req := createTestMessage()
			_, errExch := u.Exchange(req)

			assert.NoError(t, errExch)
		}(pt)
	}

	wg.Wait()
}

// TestUpstreamDoQ_serverRestart checks that the DoQ upstream keeps working
// when the server is restarted on the same address.  It also checks recovery
// after a query that failed while no server was running.
func TestUpstreamDoQ_serverRestart(t *testing.T) {
	t.Parallel()

	// Every server instance uses the same tlsConf to preserve the data
	// necessary for 0-RTT connections.
	tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")

	var (
		addr   netip.AddrPort
		upsURL string
		u      Upstream
	)

	t.Run("first_try", func(t *testing.T) {
		srv := startDoQServer(t, tlsConf, 0)

		addr = netip.MustParseAddrPort(srv.addr)
		upsURL = (&url.URL{
			Scheme: "quic",
			Host:   addr.String(),
		}).String()

		opts := &Options{
			Logger:  slogutil.NewDiscardLogger(),
			RootCAs: rootCAs,
			Timeout: 100 * time.Millisecond,
		}

		var err error
		u, err = AddressToUpstream(upsURL, opts)
		require.NoError(t, err)

		checkUpstream(t, u, upsURL)
	})
	require.False(t, t.Failed())
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// The server of each subtest is shut down in that subtest's cleanup, so a
	// new one is started on the same port.
	t.Run("second_try", func(t *testing.T) {
		_ = startDoQServer(t, tlsConf, int(addr.Port()))

		checkUpstream(t, u, upsURL)
	})
	require.False(t, t.Failed())

	// No server is running at first, so the query fails.  Once the server is
	// back, queries must succeed again.
	t.Run("retry", func(t *testing.T) {
		_, err := u.Exchange(createTestMessage())
		require.Error(t, err)

		_ = startDoQServer(t, tlsConf, int(addr.Port()))

		checkUpstream(t, u, upsURL)
	})
}

// TestUpstreamDoQ_0RTT checks that a reconnection to the same server uses
// 0-RTT, while the very first connection does not.
func TestUpstreamDoQ_0RTT(t *testing.T) {
	tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	srv := startDoQServer(t, tlsConf, 0)

	tracer := &quicTracer{}
	upsAddr := fmt.Sprintf("quic://%s", srv.addr)
	u, err := AddressToUpstream(upsAddr, &Options{
		Logger:     slogutil.NewDiscardLogger(),
		QUICTracer: tracer.TracerForConnection,
		RootCAs:    rootCAs,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	uq := u.(*dnsOverQUIC)
	req := createTestMessage()

	// The first exchange establishes the initial connection.
	resp, err := uq.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	// Close and forget the cached connection, so that the next exchange has to
	// reconnect.
	dropConn := func() {
		uq.connMu.Lock()
		defer uq.connMu.Unlock()

		err = uq.conn.CloseWithError(QUICCodeNoError, "")
		require.NoError(t, err)

		uq.conn = nil
	}
	dropConn()

	resp, err = uq.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	conns := tracer.getConnectionsInfo()
	require.Len(t, conns, 2)

	// Only the second connection is expected to use 0-RTT.
	require.False(t, conns[0].is0RTT())
	require.True(t, conns[1].is0RTT())
}

// testDoQServer is an instance of a test DNS-over-QUIC server.
type testDoQServer struct {
	// listener is the QUIC connections listener.
	listener *quic.EarlyListener

	// logger is used for serving errors logging.
	logger *slog.Logger

	// conns is the set of connections that are currently active.
	conns map[quic.EarlyConnection]struct{}

	// connsMu protects conns.
	connsMu *sync.Mutex

	// addr is the address that this server listens to.
	addr string
}

// Shutdown stops the test server.  It first closes all the active connections
// and then the listener, and returns the errors of both steps joined.
func (s *testDoQServer) Shutdown() (err error) {
	return errors.Join(s.closeConns(), s.listener.Close())
}

// Serve accepts QUIC connections and handles each of them in a separate
// goroutine.  It returns once accepting fails.  Normally that happens when
// Shutdown closes the listener, which is logged at the debug level; any other
// failure is logged as an error.
func (s *testDoQServer) Serve() {
	for {
		// Wait for a connection without a timeout: closing the listener
		// unblocks Accept with quic.ErrServerClosed.  A timeout here would make
		// the server stop serving after a period of inactivity.
		conn, err := s.listener.Accept(context.Background())
		if err != nil {
			if errors.Is(err, quic.ErrServerClosed) {
				s.logger.Debug("accept failed", slogutil.KeyError, err)
			} else {
				s.logger.Error("accept failed", slogutil.KeyError, err)
			}

			return
		}

		go s.handleQUICConnection(conn)
	}
}

// handleQUICConnection registers conn as active and serves each of its streams
// in a separate goroutine.  Once accepting a stream fails, it closes conn and
// unregisters it.
func (s *testDoQServer) handleQUICConnection(conn quic.EarlyConnection) {
	s.addConn(conn)
	defer s.closeConn(conn)

	for {
		ctx := context.Background()

		stream, err := conn.AcceptStream(ctx)
		if err != nil {
			return
		}

		go s.serveStream(ctx, conn, stream)
	}
}

// serveStream handles a single stream of conn.  If handling fails, it logs the
// error and closes the whole conn.
func (s *testDoQServer) serveStream(
	ctx context.Context,
	conn quic.EarlyConnection,
	stream quic.Stream,
) {
	handleErr := s.handleQUICStream(ctx, stream)
	if handleErr == nil {
		return
	}

	s.logger.Error("handling", "raddr", conn.RemoteAddr(), slogutil.KeyError, handleErr)

	_ = conn.CloseWithError(QUICCodeNoError, "")
}

// handleQUICStream reads a single DNS query from stream, writes the response
// built by respondToTestMessage back, and closes the stream.
//
// The query is encoded as a 2-octet length field followed by the message, as
// RFC 9250 requires.  Both parts are read with io.ReadFull, since a single Read
// may return only part of the data.
func (s *testDoQServer) handleQUICStream(ctx context.Context, stream quic.Stream) (err error) {
	defer slogutil.CloseAndLog(ctx, s.logger, stream, slog.LevelDebug)

	var prefix [2]byte
	_, err = io.ReadFull(stream, prefix[:])
	if err != nil {
		return fmt.Errorf("reading length prefix: %w", err)
	}

	buf := make([]byte, binary.BigEndian.Uint16(prefix[:]))
	_, err = io.ReadFull(stream, buf)
	if err != nil {
		return fmt.Errorf("reading query: %w", err)
	}

	// Only one query per stream is expected.
	stream.CancelRead(0)

	req := &dns.Msg{}
	err = req.Unpack(buf)
	if err != nil {
		return fmt.Errorf("unpacking query: %w", err)
	}

	resp := respondToTestMessage(req)

	buf, err = resp.Pack()
	if err != nil {
		return fmt.Errorf("packing response: %w", err)
	}

	_, err = stream.Write(proxyutil.AddPrefix(buf))

	return err
}

// addConn registers conn in the set of active connections.
func (s *testDoQServer) addConn(conn quic.EarlyConnection) {
	s.connsMu.Lock()
	s.conns[conn] = struct{}{}
	s.connsMu.Unlock()
}

// closeConn closes conn with QUICCodeNoError and removes it from the set of
// active connections.  A failure to close is only logged.
func (s *testDoQServer) closeConn(conn quic.EarlyConnection) {
	s.connsMu.Lock()
	defer s.connsMu.Unlock()

	if closeErr := conn.CloseWithError(QUICCodeNoError, ""); closeErr != nil {
		s.logger.Debug("failed to close conn", slogutil.KeyError, closeErr)
	}

	delete(s.conns, conn)
}

// closeConns closes every active connection with QUICCodeNoError and empties
// the set.  It returns the close errors joined.
func (s *testDoQServer) closeConns() (err error) {
	s.connsMu.Lock()
	defer s.connsMu.Unlock()

	var errs []error
	for conn := range s.conns {
		if closeErr := conn.CloseWithError(QUICCodeNoError, ""); closeErr != nil {
			errs = append(errs, closeErr)
		}
	}

	clear(s.conns)

	return errors.Join(errs...)
}

// startDoQServer starts a test DoQ server on 127.0.0.1 at port, or at a port
// chosen by the system when port is zero.  Note that it overwrites the
// NextProtos of tlsConf with the DoQ ALPN token.  The server, its transport and
// its UDP socket are closed in the cleanup of t.
func startDoQServer(t *testing.T, tlsConf *tls.Config, port int) (s *testDoQServer) {
	tlsConf.NextProtos = []string{NextProtoDQ}

	laddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", port))
	require.NoError(t, err)

	udpConn, err := net.ListenUDP("udp", laddr)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, udpConn.Close)

	tr := &quic.Transport{
		Conn: udpConn,
		// Necessary for 0-RTT.
		VerifySourceAddress: func(_ net.Addr) bool {
			return true
		},
	}

	quicConf := &quic.Config{
		Allow0RTT: true,
	}

	l, err := tr.ListenEarly(tlsConf, quicConf)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, tr.Close)

	s = &testDoQServer{
		listener: l,
		addr:     l.Addr().String(),
		// TODO(d.kolyshev): Add a concurrent safe [slog.Handler] wrapper for
		// [testing.TB] log function.
		logger:  slogutil.NewDiscardLogger(),
		conns:   map[quic.EarlyConnection]struct{}{},
		connsMu: &sync.Mutex{},
	}

	go s.Serve()
	testutil.CleanupAndRequireSuccess(t, s.Shutdown)

	return s
}

// quicTracer implements the logging.Tracer interface.  It creates a
// *quicConnTracer for every traced connection and keeps all of them.
type quicTracer struct {
	// tracers are the connection tracers in the order of their creation.
	tracers []*quicConnTracer

	// mu protects fields of *quicTracer and also protects fields of every
	// nested *quicConnTracer.
	mu sync.Mutex
}

// TracerForConnection implements the logging.Tracer interface for *quicTracer.
// It registers a new *quicConnTracer for the connection identified by odcid.
// The returned tracer records the headers of sent long header packets.
func (q *quicTracer) TracerForConnection(
	_ context.Context,
	_ logging.Perspective,
	odcid logging.ConnectionID,
) (connTracer *logging.ConnectionTracer) {
	ct := &quicConnTracer{
		id:     odcid,
		parent: q,
	}

	q.mu.Lock()
	q.tracers = append(q.tracers, ct)
	q.mu.Unlock()

	return &logging.ConnectionTracer{
		SentLongHeaderPacket: ct.SentLongHeaderPacket,
	}
}

// connInfo contains information about packets that we've logged for a single
// connection.  getConnectionsInfo fills packets and id from the corresponding
// *quicConnTracer.
type connInfo struct {
	packets []logging.Header
	id      logging.ConnectionID
}

// is0RTT reports whether any of the logged packets of the connection is a
// 0-RTT packet.
func (c *connInfo) is0RTT() (ok bool) {
	for i := range c.packets {
		// Use a copy, so that the pointer below does not refer to the slice
		// shared with the tracer.
		hdr := c.packets[i]
		if logging.PacketTypeFromHeader(&hdr) == logging.PacketType0RTT {
			return true
		}
	}

	return false
}

// getConnectionsInfo returns the information about every traced connection,
// in the order in which the connections were traced.  The packets slices are
// shared with the tracers, not copied.
func (q *quicTracer) getConnectionsInfo() (conns []connInfo) {
	q.mu.Lock()
	defer q.mu.Unlock()

	for i := range q.tracers {
		ct := q.tracers[i]
		conns = append(conns, connInfo{
			packets: ct.packets,
			id:      ct.id,
		})
	}

	return conns
}

// quicConnTracer implements the logging.ConnectionTracer interface.  It
// records the headers of the long header packets sent on a single connection.
//
// parent is the *quicTracer that created it; the mutex of parent also protects
// packets.  id is the connection ID (odcid) passed to TracerForConnection.
type quicConnTracer struct {
	parent  *quicTracer
	packets []logging.Header
	id      logging.ConnectionID
}

// SentLongHeaderPacket implements the logging.ConnectionTracer interface for
// *quicConnTracer.  It appends the header of the sent packet to q.packets.
func (q *quicConnTracer) SentLongHeaderPacket(
	hdr *logging.ExtendedHeader,
	_ logging.ByteCount,
	_ logging.ECN,
	_ *logging.AckFrame,
	_ []logging.Frame,
) {
	// The mutex of the parent protects the fields of every connection tracer.
	mu := &q.parent.mu
	mu.Lock()
	defer mu.Unlock()

	q.packets = append(q.packets, hdr.Header)
}
0707010000009A000081A4000000000000000000000001679A649F00001BA3000000000000000000000000000000000000002000000000dnsproxy-0.75.0/upstream/dot.gopackage upstream

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/url"
	"os"
	"runtime"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
)

// dialTimeout is the global timeout for establishing a TLS connection.  It is
// set as the connection deadline by tlsDial before the TLS handshake.  It is
// also set by dnsOverTLS.conn when a connection is taken from the pool.
//
// TODO(ameshkov): use bootstrap timeout instead.
const dialTimeout = 10 * time.Second

// dnsOverTLS implements the [Upstream] interface for the DNS-over-TLS protocol.
type dnsOverTLS struct {
	// addr is the DNS-over-TLS server URL.
	addr *url.URL

	// getDialer either returns an initialized dial handler or creates a
	// new one.
	getDialer DialerInitializer

	// tlsConf is the configuration of TLS.  A clone of it is used for every
	// new connection.
	tlsConf *tls.Config

	// connsMu protects conns.
	connsMu *sync.Mutex

	// logger is used for exchange logging.  It is never nil.
	logger *slog.Logger

	// conns stores the connections ready for reuse.  Don't use [sync.Pool]
	// here, since there is no need to deallocate these connections.
	//
	// TODO(e.burkov, ameshkov):  Currently connections are just stored in
	// FILO order, which eventually makes most of them unusable due to
	// timeouts.  This leads to weak performance for all exchanges coming
	// across such connections.
	conns []net.Conn
}

// newDoT creates a DNS-over-TLS upstream for addr.  The default DoT port is
// added to addr when it has none.  A finalizer closes the pooled connections
// if Close is never called.
func newDoT(addr *url.URL, opts *Options) (ups Upstream, err error) {
	addPort(addr, defaultPortDoT)

	getDialer := newDialerInitializer(addr, opts)

	tlsConf := &tls.Config{
		ServerName:   addr.Hostname(),
		RootCAs:      opts.RootCAs,
		CipherSuites: opts.CipherSuites,
		// An LRU cache of the default capacity is used, since keeping several
		// sessions helps when load balancing on the server side routes the
		// client to different servers.
		ClientSessionCache: tls.NewLRUClientSessionCache(0),
		MinVersion:         tls.VersionTLS12,
		// #nosec G402 -- TLS certificate verification could be disabled by
		// configuration.
		InsecureSkipVerify:    opts.InsecureSkipVerify,
		VerifyPeerCertificate: opts.VerifyServerCertificate,
		VerifyConnection:      opts.VerifyConnection,
	}

	tlsUps := &dnsOverTLS{
		addr:      addr,
		getDialer: getDialer,
		tlsConf:   tlsConf,
		connsMu:   &sync.Mutex{},
		logger:    opts.Logger,
	}

	runtime.SetFinalizer(tlsUps, (*dnsOverTLS).Close)

	return tlsUps, nil
}

// type check: the build fails if *dnsOverTLS stops implementing [Upstream].
var _ Upstream = (*dnsOverTLS)(nil)

// Address implements the [Upstream] interface for *dnsOverTLS.  It returns the
// upstream URL in its string form.
func (p *dnsOverTLS) Address() string {
	return p.addr.String()
}

// Exchange implements the [Upstream] interface for *dnsOverTLS.
//
// The query is first sent over a pooled connection or, if there is none, a
// newly dialed one.  If that fails, the connection is closed and the query is
// retried once over another freshly dialed connection.  The retry does not
// use the pool, because the next pooled connection may be broken as well.  A
// connection that served the query successfully is put back into the pool.
func (p *dnsOverTLS) Exchange(req *dns.Msg) (reply *dns.Msg, err error) {
	dial, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
	}

	conn, err := p.conn(dial)
	if err != nil {
		return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
	}

	reply, err = p.exchangeWithConn(conn, req)
	if err == nil {
		p.putBack(conn)

		return reply, nil
	}

	// The pooled connection might have been closed already, see
	// https://github.com/AdguardTeam/dnsproxy/issues/3.
	err = errors.WithDeferred(err, conn.Close())
	p.logger.Debug("dot got bad conn from pool", "addr", p.addr, slogutil.KeyError, err)

	conn, err = tlsDial(dial, p.tlsConf.Clone())
	if err != nil {
		return nil, fmt.Errorf(
			"dialing %s: connecting to %s: %w",
			p.addr,
			p.tlsConf.ServerName,
			err,
		)
	}

	reply, err = p.exchangeWithConn(conn, req)
	if err != nil {
		return reply, errors.WithDeferred(err, conn.Close())
	}

	p.putBack(conn)

	return reply, nil
}

// Close implements the [Upstream] interface for *dnsOverTLS.  It removes the
// finalizer, closes all the pooled connections, and empties the pool.  Closed
// connections are therefore neither kept referenced nor handed out again.
// Errors that isCriticalTCP considers expected when closing a TCP connection
// are ignored; the remaining ones are returned joined.
func (p *dnsOverTLS) Close() (err error) {
	runtime.SetFinalizer(p, nil)

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	var closeErrs []error
	for _, conn := range p.conns {
		closeErr := conn.Close()
		if closeErr != nil && isCriticalTCP(closeErr) {
			closeErrs = append(closeErrs, closeErr)
		}
	}

	// All the pooled connections are closed now, so drop them.
	p.conns = nil

	return errors.Join(closeErrs...)
}

// conn returns a usable connection from the pool, if there is one, or dials a
// new one otherwise.  Dialing happens without holding connsMu.
func (p *dnsOverTLS) conn(h bootstrap.DialHandler) (conn net.Conn, err error) {
	conn = p.popConn()
	if conn != nil {
		return conn, nil
	}

	conn, err = tlsDial(h, p.tlsConf.Clone())

	return conn, errors.Annotate(err, "connecting to %s: %w", p.tlsConf.ServerName)
}

// popConn removes the most recently pooled connection from the pool and moves
// its deadline dialTimeout into the future.  It returns nil if the pool is
// empty or if the deadline can't be set.
func (p *dnsOverTLS) popConn() (conn net.Conn) {
	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	last := len(p.conns) - 1
	if last < 0 {
		return nil
	}

	conn = p.conns[last]
	p.conns = p.conns[:last]

	err := conn.SetDeadline(time.Now().Add(dialTimeout))
	if err != nil {
		p.logger.Debug("dot upstream setting deadline to conn from pool", slogutil.KeyError, err)

		// A connection whose deadline can't be updated is presumably already
		// closed.
		return nil
	}

	p.logger.Debug("dot upstream using existing conn", "raddr", conn.RemoteAddr())

	return conn
}

// putBack returns conn to the pool, so that following exchanges can reuse it.
func (p *dnsOverTLS) putBack(conn net.Conn) {
	p.connsMu.Lock()
	p.conns = append(p.conns, conn)
	p.connsMu.Unlock()
}

// exchangeWithConn writes req to conn and reads the reply from it.  A reply
// whose ID differs from the ID of req is returned together with dns.ErrId.
func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, req *dns.Msg) (reply *dns.Msg, err error) {
	addr := p.Address()

	logBegin(p.logger, addr, networkTCP, req)
	defer func() { logFinish(p.logger, addr, networkTCP, err) }()

	dnsConn := dns.Conn{Conn: conn}

	if err = dnsConn.WriteMsg(req); err != nil {
		return nil, fmt.Errorf("sending request to %s: %w", addr, err)
	}

	if reply, err = dnsConn.ReadMsg(); err != nil {
		return nil, fmt.Errorf("reading response from %s: %w", addr, err)
	}

	if reply.Id != req.Id {
		return reply, dns.ErrId
	}

	return reply, nil
}

// tlsDial is basically the same as tls.DialWithDialer, but it calls
// dialContext to get the underlying connection.  The address passed to
// dialContext is empty, since the dialer uses the bootstrapped address.
//
// dialTimeout is set as the deadline of the connection before the TLS
// handshake, so the handshake is bounded by it.  If the deadline can't be set,
// or the handshake fails, the connection is closed and an error is returned.
func tlsDial(dialContext bootstrap.DialHandler, conf *tls.Config) (c *tls.Conn, err error) {
	rawConn, err := dialContext(context.Background(), networkTCP, "")
	if err != nil {
		return nil, err
	}

	// Set the deadline on the raw connection; a *tls.Conn delegates
	// SetDeadline to it anyway.  A failure is reported to the caller instead
	// of panicking, and the connection is not leaked.
	err = rawConn.SetDeadline(time.Now().Add(dialTimeout))
	if err != nil {
		return nil, errors.WithDeferred(
			fmt.Errorf("setting deadline: %w", err),
			rawConn.Close(),
		)
	}

	conn := tls.Client(rawConn, conf)

	err = conn.Handshake()
	if err != nil {
		return nil, errors.WithDeferred(err, conn.Close())
	}

	return conn, nil
}

// isCriticalTCP reports whether err, returned while closing a TCP connection,
// is unexpected.  The following are all considered expected:
//   - timeouts;
//   - EOF;
//   - use of an already closed connection;
//   - exceeded deadlines;
//   - broken connections, as reported by isConnBroken.
func isCriticalTCP(err error) (ok bool) {
	if netErr := net.Error(nil); errors.As(err, &netErr) && netErr.Timeout() {
		return false
	}

	expected := errors.Is(err, io.EOF) ||
		errors.Is(err, net.ErrClosed) ||
		errors.Is(err, os.ErrDeadlineExceeded) ||
		isConnBroken(err)

	return !expected
}
0707010000009B000081A4000000000000000000000001679A649F00001E54000000000000000000000000000000000000002E00000000dnsproxy-0.75.0/upstream/dot_internal_test.gopackage upstream

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"net"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestUpstream_dnsOverTLS checks that an upstream created from a "tls://"
// address resolves test messages against a local DoT server.  The check is
// repeated, so pooled connections get used as well.
func TestUpstream_dnsOverTLS(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		resp := respondToTestMessage(req)

		err := w.WriteMsg(resp)

		// The handler runs outside of the test goroutine, so use PanicT
		// instead of t.
		pt := testutil.PanicT{}
		require.NoError(pt, err)
	})

	// Create a DoT upstream that we'll be testing.
	addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Test that it responds properly.
	for range 10 {
		checkUpstream(t, u, addr)
	}
}

// TestUpstream_dnsOverTLS_race uses a single DoT upstream from several
// goroutines at once.  It is meant to be run with the race detector, which
// reports unsynchronized access to the upstream's state.
func TestUpstream_dnsOverTLS_race(t *testing.T) {
	const count = 10

	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		resp := respondToTestMessage(req)

		err := w.WriteMsg(resp)

		pt := testutil.PanicT{}
		require.NoError(pt, err)
	})

	// Creating a DoT upstream that we will be testing.
	addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Use this upstream from multiple goroutines in parallel.
	wg := sync.WaitGroup{}
	for range count {
		wg.Add(1)
		go func() {
			defer wg.Done()

			// t must not be used to fail the test from other goroutines.
			pt := testutil.PanicT{}

			req := createTestMessage()
			resp, uErr := u.Exchange(req)
			require.NoError(pt, uErr)
			requireResponse(pt, req, resp)
		}()
	}

	wg.Wait()
}

// TestUpstream_dnsOverTLS_poolReconnect checks that when the pooled connection
// gets closed, the next exchange still succeeds, the pool again holds exactly
// one (new) connection, and the TLS session is resumed.
//
// TODO(e.burkov, a.garipov):  Add to golibs and use here some kind of helper
// for type assertion of interface types.
//
// NOTE(review): testutil.RequireTypeAssert from golibs is already used below,
// so the TODO above looks resolved; confirm and remove it.
func TestUpstream_dnsOverTLS_poolReconnect(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})

	// This var is used to store the last connection state in order to check
	// if session resumption works as expected.
	var lastState tls.ConnectionState

	// Init the upstream to the test DoT server that also keeps track of the
	// session resumptions.
	addr := (&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String()
	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
		VerifyConnection: func(state tls.ConnectionState) error {
			lastState = state

			return nil
		},
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Access the pool directly.
	p := testutil.RequireTypeAssert[*dnsOverTLS](t, u)

	// Send the first test message.
	req := createTestMessage()
	reply, err := u.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, reply)

	// Now let's close the pooled connection.
	require.Len(t, p.conns, 1)
	conn := p.conns[0]
	require.NoError(t, conn.Close())

	// Send the second test message.
	req = createTestMessage()
	reply, err = u.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, reply)

	// Now assert that the number of connections in the pool is not changed,
	// but the closed connection was replaced.
	require.Len(t, p.conns, 1)
	assert.NotSame(t, conn, p.conns[0])

	// Check that the session was resumed on the last attempt.
	assert.True(t, lastState.DidResume)
}

// TestUpstream_dnsOverTLS_poolDeadline checks several things about a pooled
// connection:
//   - conn returns it from the pool;
//   - it can be reused after being put back;
//   - once its deadline is set in the past, exchanging over it fails.
func TestUpstream_dnsOverTLS_poolDeadline(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})

	// Create a DoT upstream that we'll be testing.
	addr := (&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String()
	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Send the first test message.
	req := createTestMessage()
	response, err := u.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, response)

	p := testutil.RequireTypeAssert[*dnsOverTLS](t, u)

	// Now let's get connection from the pool and use it again.
	require.Len(t, p.conns, 1)
	conn := p.conns[0]

	dialHandler, err := p.getDialer()
	require.NoError(t, err)

	// The pooled connection must be taken instead of dialing a new one.
	usedConn, err := p.conn(dialHandler)
	require.NoError(t, err)
	require.Same(t, usedConn, conn)

	response, err = p.exchangeWithConn(conn, req)
	require.NoError(t, err)
	requireResponse(t, req, response)

	// Update the connection's deadLine.
	err = conn.SetDeadline(time.Now().Add(10 * time.Hour))
	require.NoError(t, err)

	p.putBack(conn)

	// Get connection from the pool and reuse it.
	require.Len(t, p.conns, 1)
	conn = p.conns[0]

	usedConn, err = p.conn(dialHandler)
	require.NoError(t, err)
	require.Same(t, usedConn, conn)

	response, err = p.exchangeWithConn(usedConn, req)
	require.NoError(t, err)
	requireResponse(t, req, response)

	// Set connection's deadLine to the past and try to reuse it.
	err = usedConn.SetDeadline(time.Now().Add(-10 * time.Hour))
	require.NoError(t, err)

	// Connection with expired deadLine can't be used.
	response, err = p.exchangeWithConn(usedConn, req)
	require.Error(t, err)
	require.Nil(t, response)
}

// testDoTServer is a test DNS-over-TLS server that can be used in unit-tests.
// Use startDoTServer to create it.
type testDoTServer struct {
	// srv is the *dns.Server instance that listens for DoT requests.
	srv *dns.Server

	// tlsConfig is the TLS configuration that is used for this server.
	tlsConfig *tls.Config

	// rootCAs is the pool with root certificates used by the test server.
	rootCAs *x509.CertPool

	// port is the TCP port on 127.0.0.1 the server listens on.
	port int
}

// type check
var _ io.Closer = (*testDoTServer)(nil)

// startDoTServer starts *testDoTServer on a random port of 127.0.0.1.  It uses
// a TLS configuration generated for that address and serves requests with
// handler.  The server is closed in tb's cleanup, so callers don't need to
// close it themselves.
//
// TODO(e.burkov):  Also return address?
func startDoTServer(tb testing.TB, handler dns.HandlerFunc) (s *testDoTServer) {
	tb.Helper()

	tcpListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(tb, err)

	tlsConfig, rootCAs := createServerTLSConfig(tb, "127.0.0.1")
	tlsListener := tls.NewListener(tcpListener, tlsConfig)

	srv := &dns.Server{
		Listener:  tlsListener,
		TLSConfig: tlsConfig,
		Net:       "tls",
		Handler:   handler,
	}

	// Serve in the background.  tb must not be used to fail the test from
	// goroutines other than the test one, hence PanicT.
	go func() {
		pt := testutil.PanicT{}
		require.NoError(pt, srv.ActivateAndServe())
	}()

	s = &testDoTServer{
		srv:       srv,
		tlsConfig: tlsConfig,
		rootCAs:   rootCAs,
		port:      tcpListener.Addr().(*net.TCPAddr).Port,
	}
	testutil.CleanupAndRequireSuccess(tb, s.Close)

	return s
}

// Close implements the io.Closer interface for *testDoTServer.  It shuts the
// underlying DNS server down.
func (s *testDoTServer) Close() error {
	srv := s.srv

	return srv.Shutdown()
}

// BenchmarkDoTUpstream measures parallel exchanges through a DoT upstream
// connected to a local test server.  Requests are created in advance by a
// separate goroutine and buffered in a channel.
func BenchmarkDoTUpstream(b *testing.B) {
	srv := startDoTServer(b, func(w dns.ResponseWriter, m *dns.Msg) {
		err := w.WriteMsg(respondToTestMessage(m))
		require.NoError(testutil.PanicT{}, err)
	})

	addr := (&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String()

	u, err := AddressToUpstream(addr, &Options{
		Logger:             slogutil.NewDiscardLogger(),
		InsecureSkipVerify: true,
	})
	require.NoError(b, err)
	testutil.CleanupAndRequireSuccess(b, u.Close)

	// Stop the generator once the benchmark is done.  Otherwise it stays
	// blocked on sending to reqChan for the rest of the test binary's life.
	done := make(chan struct{})
	b.Cleanup(func() { close(done) })

	reqChan := make(chan *dns.Msg, 64)
	go func() {
		for {
			select {
			case reqChan <- createTestMessage():
				// Go on.
			case <-done:
				return
			}
		}
	}()

	// Wait for channel to fill.
	require.Eventually(b, func() bool {
		return len(reqChan) == cap(reqChan)
	}, time.Second, time.Millisecond)

	b.Run("exchange_p", func(b *testing.B) {
		b.ResetTimer()
		b.ReportAllocs()

		b.RunParallel(func(p *testing.PB) {
			for p.Next() {
				_, _ = u.Exchange(<-reqChan)
			}
		})
	})
}
0707010000009C000081A4000000000000000000000001679A649F00000152000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/dot_unix.go//go:build darwin || freebsd || linux || openbsd || netbsd

package upstream

import (
	"github.com/AdguardTeam/golibs/errors"
	"golang.org/x/sys/unix"
)

// isConnBroken reports whether err means that a connection is broken, i.e.
// whether it wraps either EPIPE or ETIMEDOUT.
func isConnBroken(err error) (ok bool) {
	switch {
	case errors.Is(err, unix.EPIPE), errors.Is(err, unix.ETIMEDOUT):
		return true
	default:
		return false
	}
}
0707010000009D000081A4000000000000000000000001679A649F00000141000000000000000000000000000000000000002800000000dnsproxy-0.75.0/upstream/dot_windows.go//go:build windows

package upstream

import (
	"github.com/AdguardTeam/golibs/errors"
	"golang.org/x/sys/windows"
)

// isConnBroken reports whether err means that a connection is broken, i.e.
// whether it wraps either WSAECONNABORTED or WSAECONNRESET.
func isConnBroken(err error) (ok bool) {
	switch {
	case errors.Is(err, windows.WSAECONNABORTED), errors.Is(err, windows.WSAECONNRESET):
		return true
	default:
		return false
	}
}
0707010000009E000081A4000000000000000000000001679A649F00000A52000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/upstream/hostsresolver.gopackage upstream

import (
	"context"
	"fmt"
	"io/fs"
	"log/slog"
	"net/netip"
	"slices"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/hostsfile"
)

// HostsResolver is a [Resolver] that looks into system hosts files, see
// [hostsfile].  Use [NewHostsResolver] or [NewDefaultHostsResolver] to create
// it.
type HostsResolver struct {
	// strg contains all the hosts file data needed for lookups.
	strg hostsfile.Storage
}

// NewHostsResolver returns a resolver that looks up addresses in hosts, which
// is used as is, without copying.
func NewHostsResolver(hosts hostsfile.Storage) (hr *HostsResolver) {
	hr = &HostsResolver{strg: hosts}

	return hr
}

// NewDefaultHostsResolver returns a resolver based on the system hosts files
// found at [hostsfile.DefaultHostsPaths] within rootFSys.  A missing file is
// not an error; it only produces a debug log record.  If l is nil,
// [slog.Default] is used.
func NewDefaultHostsResolver(rootFSys fs.FS, l *slog.Logger) (hr *HostsResolver, err error) {
	if l == nil {
		l = slog.Default()
	}

	paths, err := hostsfile.DefaultHostsPaths()
	if err != nil {
		return nil, fmt.Errorf("getting default hosts paths: %w", err)
	}

	// No readers are passed, so the error is always nil here.
	strg, _ := hostsfile.NewDefaultStorage()
	for i := range paths {
		if err = parseHostsFile(rootFSys, strg, paths[i], l); err != nil {
			// The error is informative enough as is.
			return nil, err
		}
	}

	return NewHostsResolver(strg), nil
}

// parseHostsFile opens filename within fsys and parses its contents into
// hosts.  A file that doesn't exist is skipped with a debug log record rather
// than reported as an error.  The file is always closed, and a close error is
// combined with the parsing one.
func parseHostsFile(fsys fs.FS, hosts hostsfile.Set, filename string, l *slog.Logger) (err error) {
	file, err := fsys.Open(filename)
	switch {
	case errors.Is(err, fs.ErrNotExist):
		l.Debug("hosts file does not exist", "filename", filename)

		return nil
	case err != nil:
		// The error is informative enough as is.
		return err
	}

	defer func() { err = errors.WithDeferred(err, file.Close()) }()

	return hostsfile.Parse(hosts, file, nil)
}

// type check
var _ Resolver = (*HostsResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *HostsResolver.  It looks
// host up in the hosts files data and returns a new slice of its addresses,
// filtered according to network:
//
//   - "ip" returns all the addresses;
//   - "ip4" returns only the ones for which [netip.Addr.Is4] is true;
//   - "ip6" returns only the ones for which [netip.Addr.Is6] is true.
//
// Any other network results in an error.  An unknown host results in no
// addresses and a nil error.  The context is currently unused.
func (hr *HostsResolver) LookupNetIP(
	_ context.Context,
	network string,
	host string,
) (addrs []netip.Addr, err error) {
	var ipMatches func(netip.Addr) (ok bool)
	switch network {
	case "ip4":
		ipMatches = netip.Addr.Is4
	case "ip6":
		ipMatches = netip.Addr.Is6
	case "ip":
		// Clone, so that callers can't modify the slice held by the storage.
		return slices.Clone(hr.strg.ByName(host)), nil
	default:
		return nil, fmt.Errorf("unsupported network %q", network)
	}

	for _, addr := range hr.strg.ByName(host) {
		if ipMatches(addr) {
			addrs = append(addrs, addr)
		}
	}

	return addrs, nil
}
0707010000009F000081A4000000000000000000000001679A649F0000093C000000000000000000000000000000000000002F00000000dnsproxy-0.75.0/upstream/hostsresolver_test.gopackage upstream_test

import (
	"context"
	"net/netip"
	"testing"
	"testing/fstest"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/hostsfile"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestHostsResolver_LookupNetIP checks lookups of canonical names and aliases
// for every supported network.  It also covers unknown hosts, address family
// mismatches, and an unsupported network.  The hosts data is served from an
// in-memory filesystem at the first default hosts path.
func TestHostsResolver_LookupNetIP(t *testing.T) {
	const hostsData = `
1.2.3.4 host1 host2 ipv4.only
::1 host1 host2 ipv6.only
`

	var (
		v4Addr = netip.MustParseAddr("1.2.3.4")
		v6Addr = netip.MustParseAddr("::1")
	)

	paths, err := hostsfile.DefaultHostsPaths()
	require.NoError(t, err)
	require.NotEmpty(t, paths)

	fsys := fstest.MapFS{
		paths[0]: {
			Data: []byte(hostsData),
		},
	}

	hr, err := upstream.NewDefaultHostsResolver(fsys, slogutil.NewDiscardLogger())
	require.NoError(t, err)

	testCases := []struct {
		name      string
		host      string
		net       string
		wantAddrs []netip.Addr
	}{{
		name:      "canonical_any",
		host:      "host1",
		net:       "ip",
		wantAddrs: []netip.Addr{v4Addr, v6Addr},
	}, {
		name:      "canonical_v4",
		host:      "host1",
		net:       "ip4",
		wantAddrs: []netip.Addr{v4Addr},
	}, {
		name:      "canonical_v6",
		host:      "host1",
		net:       "ip6",
		wantAddrs: []netip.Addr{v6Addr},
	}, {
		name:      "alias_any",
		host:      "host2",
		net:       "ip",
		wantAddrs: []netip.Addr{v4Addr, v6Addr},
	}, {
		name:      "alias_v4",
		host:      "host2",
		net:       "ip4",
		wantAddrs: []netip.Addr{v4Addr},
	}, {
		name:      "alias_v6",
		host:      "host2",
		net:       "ip6",
		wantAddrs: []netip.Addr{v6Addr},
	}, {
		name:      "unknown_host",
		host:      "host3",
		net:       "ip",
		wantAddrs: nil,
	}, {
		name:      "family_mismatch_v4",
		host:      "ipv6.only",
		net:       "ip4",
		wantAddrs: nil,
	}, {
		name:      "family_mismatch_v6",
		host:      "ipv4.only",
		net:       "ip6",
		wantAddrs: nil,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// The outer err is shared between subtests, which is fine only
			// because they don't call t.Parallel.
			var addrs []netip.Addr
			addrs, err = hr.LookupNetIP(context.Background(), tc.net, tc.host)
			require.NoError(t, err)

			assert.Equal(t, tc.wantAddrs, addrs)
		})
	}

	t.Run("unsupported_network", func(t *testing.T) {
		_, err = hr.LookupNetIP(context.Background(), "ip5", "host1")
		testutil.AssertErrorMsg(t, `unsupported network "ip5"`, err)
	})
}
070701000000A0000081A4000000000000000000000001679A649F000010C9000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/parallel.gopackage upstream

import (
	"fmt"
	"slices"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/miekg/dns"
)

const (
	// ErrNoUpstreams is returned from the methods that expect at least a single
	// upstream to work with when no upstreams are specified.
	ErrNoUpstreams errors.Error = "no upstream specified"

	// ErrNoReply is returned when an upstream returned neither a response nor
	// an error.  For example, [ExchangeAll] returns it when its only upstream
	// doesn't reply.
	ErrNoReply errors.Error = "no reply"
)

// ExchangeParallel sends req to all of ups and returns the first successful
// response along with the upstream that produced it.  With several upstreams
// the exchanges run concurrently on copies of req, and nil responses aren't
// considered successful.  With a single upstream, req is exchanged directly.
// It returns [ErrNoUpstreams] if ups is empty, and an error if all upstreams
// failed to exchange the request.
func ExchangeParallel(ups []Upstream, req *dns.Msg) (reply *dns.Msg, resolved Upstream, err error) {
	if len(ups) == 0 {
		return nil, nil, ErrNoUpstreams
	}

	if len(ups) == 1 {
		return exchangeSingle(ups[0], req)
	}

	// The buffer fits every result, so goroutines that finish after the first
	// success never block.
	results := make(chan any, len(ups))
	for _, u := range ups {
		// Each exchange gets its own copy of the request, as [dns.Client] can
		// modify it during the exchange.
		//
		// TODO(s.chzhen):  Consider using buffer pool.
		go exchangeAsync(u, req.Copy(), results)
	}

	var errs []error
	for range ups {
		res, resErr := receiveAsyncResult(results)
		if resErr == nil {
			return res.Resp, res.Upstream, nil
		}

		if !errors.Is(resErr, ErrNoReply) {
			errs = append(errs, resErr)
		}
	}

	// TODO(e.burkov):  Probably it's better to return the joined error from
	// each upstream that returned no response, and get rid of multiple
	// [errors.Is] calls.  This will change the behavior though.
	if len(errs) == 0 {
		return nil, nil, errors.Error("none of upstream servers responded")
	}

	return nil, nil, errors.Join(errs...)
}

// exchangeSingle exchanges req with ups and returns the response with ups as
// the resolving upstream.  If ups returned neither a response nor an error,
// exchangeSingle returns [ErrNoReply].  This matches the multi-upstream path
// of [ExchangeParallel] and [ExchangeAll], so a nil response is never reported
// as a success.
func exchangeSingle(
	ups Upstream,
	req *dns.Msg,
) (resp *dns.Msg, resolved Upstream, err error) {
	resp, err = ups.Exchange(req)
	if err != nil {
		return nil, nil, err
	}

	if resp == nil {
		return nil, nil, ErrNoReply
	}

	return resp, ups, nil
}

// ExchangeAllResult is the successful result of [ExchangeAll] for a single
// upstream.  Results returned from [ExchangeAll] always have a non-nil Resp.
type ExchangeAllResult struct {
	// Resp is the response DNS request resolved into.
	Resp *dns.Msg

	// Upstream is the upstream that successfully resolved the request.
	Upstream Upstream
}

// ExchangeAll sends req to all of ups and returns the responses of those that
// replied successfully, in the order of arrival.  It returns an error only if
// all upstreams failed to exchange the request, or [ErrNoUpstreams] if ups is
// empty.  A single upstream is used directly, and its nil response is reported
// as [ErrNoReply].
func ExchangeAll(ups []Upstream, req *dns.Msg) (res []ExchangeAllResult, err error) {
	total := len(ups)
	if total == 0 {
		return nil, ErrNoUpstreams
	}

	if total == 1 {
		u := ups[0]
		reply, exErr := u.Exchange(req)
		switch {
		case exErr != nil:
			return nil, exErr
		case reply == nil:
			return nil, ErrNoReply
		default:
			return []ExchangeAllResult{{Resp: reply, Upstream: u}}, nil
		}
	}

	results := make(chan any, total)
	for _, u := range ups {
		// Each exchange gets its own copy of the request, as [dns.Client] can
		// modify it during the exchange.
		//
		// TODO(s.chzhen):  Consider using buffer pool.
		go exchangeAsync(u, req.Copy(), results)
	}

	res = make([]ExchangeAllResult, 0, total)
	var errs []error

	// Wait for every exchange to finish.
	for range ups {
		r, rErr := receiveAsyncResult(results)
		if rErr != nil {
			errs = append(errs, rErr)

			continue
		}

		res = append(res, *r)
	}

	if len(errs) == total {
		return res, fmt.Errorf("all upstreams failed: %w", errors.Join(errs...))
	}

	return slices.Clip(res), nil
}

// receiveAsyncResult receives a single value, as sent by exchangeAsync, from
// resCh.  It returns either a non-nil result with a non-nil response or a
// non-nil error.  A result without a response is reported as [ErrNoReply].
func receiveAsyncResult(resCh chan any) (res *ExchangeAllResult, err error) {
	v := <-resCh

	if e, isErr := v.(error); isErr {
		return nil, e
	}

	r, isRes := v.(*ExchangeAllResult)
	if !isRes {
		return nil, fmt.Errorf("unexpected type %T of result", v)
	}

	if r.Resp == nil {
		return nil, ErrNoReply
	}

	return r, nil
}

// exchangeAsync exchanges req with u and sends either the error or an
// *ExchangeAllResult to resCh.  It is meant to run in its own goroutine, so
// resCh should have enough buffer not to block it.
func exchangeAsync(u Upstream, req *dns.Msg, resCh chan any) {
	reply, err := u.Exchange(req)
	if err != nil {
		resCh <- err

		return
	}

	resCh <- &ExchangeAllResult{Resp: reply, Upstream: u}
}
070701000000A1000081A4000000000000000000000001679A649F00000D44000000000000000000000000000000000000003300000000dnsproxy-0.75.0/upstream/parallel_internal_test.gopackage upstream

import (
	"fmt"
	"net/netip"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

const (
	// timeout is the upstream timeout used by the parallel exchange tests.
	timeout = 2 * time.Second
)

// TestExchangeParallel launches several parallel exchanges against real
// upstreams.  It checks that the response comes from the one expected to
// answer and that it arrives within the configured timeout.
//
// NOTE(review): this test requires network access to 8.8.8.8:53.  The other
// two addresses are presumably chosen to be unresponsive.
func TestExchangeParallel(t *testing.T) {
	upstreams := []Upstream{}
	upstreamList := []string{"1.2.3.4:55", "8.8.8.1", "8.8.8.8:53"}

	for _, s := range upstreamList {
		u, err := AddressToUpstream(s, &Options{
			Logger:  slogutil.NewDiscardLogger(),
			Timeout: timeout,
		})
		if err != nil {
			t.Fatalf("cannot create upstream: %s", err)
		}
		upstreams = append(upstreams, u)
	}

	req := createTestMessage()
	start := time.Now()
	resp, u, err := ExchangeParallel(upstreams, req)
	if err != nil {
		t.Fatalf("no response from test upstreams: %s", err)
	}

	if u.Address() != "8.8.8.8:53" {
		t.Fatalf("shouldn't happen. This upstream can't resolve DNS request: %s", u.Address())
	}

	requireResponse(t, req, resp)
	elapsed := time.Since(start)
	if elapsed > timeout {
		t.Fatalf("exchange took more time than the configured timeout: %v", elapsed)
	}
}

// TestExchangeParallelEmpty checks that ExchangeParallel reports an error, and
// neither a response nor an upstream, when all upstreams return a nil response
// without an error.
func TestExchangeParallelEmpty(t *testing.T) {
	ups := []Upstream{
		&testUpstream{empty: true},
		&testUpstream{empty: true},
	}

	req := createTestMessage()
	resp, up, err := ExchangeParallel(ups, req)
	require.Error(t, err)

	assert.Nil(t, resp)
	assert.Nil(t, up)
}

// testUpstream represents a mock upstream structure.  Its fields are checked
// in this order by Exchange: sleep, then empty, then err.
type testUpstream struct {
	// addr is a mock A record IP address to be returned.  The zero value means
	// no answer records.
	addr netip.Addr

	// err is a mock error to be returned.
	err bool

	// empty indicates if a nil response is returned.
	empty bool

	// sleep is a delay before response.
	sleep time.Duration
}

// type check
var _ Upstream = (*testUpstream)(nil)

// Exchange implements the [Upstream] interface for *testUpstream.  After the
// configured delay, it returns a nil response, an error, or a reply to req.
// The reply contains a single A record if u.addr is set.
func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	if u.sleep != 0 {
		time.Sleep(u.sleep)
	}

	switch {
	case u.empty:
		return nil, nil
	case u.err:
		return nil, fmt.Errorf("upstream error")
	}

	resp = new(dns.Msg).SetReply(req)
	if u.addr.IsValid() {
		resp.Answer = append(resp.Answer, &dns.A{A: u.addr.AsSlice()})
	}

	return resp, nil
}

// Address implements the [Upstream] interface for *testUpstream.  The mock has
// no address, so it always returns an empty string.
func (u *testUpstream) Address() (addr string) {
	// addr is still its zero value here.
	return addr
}

// Close implements the [Upstream] interface for *testUpstream.  The mock holds
// no resources, so it always succeeds.
func (u *testUpstream) Close() (err error) {
	// err is still nil here.
	return err
}

// TestExchangeAll checks that ExchangeAll skips a failing upstream and returns
// the other responses in the order of arrival.  The immediate answer therefore
// comes first and the delayed one second.
func TestExchangeAll(t *testing.T) {
	delayedAnsAddr := netip.MustParseAddr("1.1.1.1")
	ansAddr := netip.MustParseAddr("3.3.3.3")

	ups := []Upstream{&testUpstream{
		addr:  delayedAnsAddr,
		sleep: 100 * time.Millisecond,
	}, &testUpstream{
		err: true,
	}, &testUpstream{
		addr: ansAddr,
	}}

	req := createHostTestMessage("test.org")
	res, err := ExchangeAll(ups, req)
	require.NoError(t, err)
	require.Len(t, res, 2)

	// The undelayed upstream answers first.
	resp := res[0].Resp
	require.NotNil(t, resp)
	require.NotEmpty(t, resp.Answer)
	require.IsType(t, new(dns.A), resp.Answer[0])

	ip := resp.Answer[0].(*dns.A).A
	assert.Equal(t, ansAddr.AsSlice(), []byte(ip))

	resp = res[1].Resp
	require.NotNil(t, resp)
	require.NotEmpty(t, resp.Answer)
	require.IsType(t, new(dns.A), resp.Answer[0])

	ip = resp.Answer[0].(*dns.A).A
	assert.Equal(t, delayedAnsAddr.AsSlice(), []byte(ip))
}
070701000000A2000081A4000000000000000000000001679A649F000015FA000000000000000000000000000000000000002200000000dnsproxy-0.75.0/upstream/plain.gopackage upstream

import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/url"
	"strings"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
)

// network is the semantic type alias of the network to pass to dialing
// functions.  It's either [networkUDP] or [networkTCP].  It may also be used as
// the URL scheme of plain upstreams.
type network = string

const (
	// networkUDP is the UDP network.
	networkUDP network = "udp"

	// networkTCP is the TCP network.
	networkTCP network = "tcp"
)

// plainDNS implements the [Upstream] interface for the regular DNS protocol.
// It doesn't keep connections between exchanges.
type plainDNS struct {
	// addr is the DNS server URL.  Scheme is always "udp" or "tcp".
	addr *url.URL

	// logger is used for exchange logging.  It is never nil.
	logger *slog.Logger

	// getDialer either returns an initialized dial handler or creates a new
	// one.
	getDialer DialerInitializer

	// net is the network of the connections.  It equals addr.Scheme.
	net network

	// timeout is the timeout for DNS requests.
	timeout time.Duration
}

// newPlain returns a plain DNS upstream for addr, which must have either the
// "udp" or the "tcp" scheme; the scheme also determines the network.  Note
// that addr is passed to addPort before being stored, which presumably fills
// in [defaultPortPlain] in place when addr has no port.
func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) {
	if addr.Scheme != networkUDP && addr.Scheme != networkTCP {
		return nil, fmt.Errorf("unsupported url scheme: %s", addr.Scheme)
	}

	addPort(addr, defaultPortPlain)

	u = &plainDNS{
		addr:      addr,
		logger:    opts.Logger,
		getDialer: newDialerInitializer(addr, opts),
		net:       addr.Scheme,
		timeout:   opts.Timeout,
	}

	return u, nil
}

// type check
var _ Upstream = (*plainDNS)(nil)

// Address implements the [Upstream] interface for *plainDNS.  UDP upstreams
// report only the host part of their URL, and TCP ones the whole URL including
// the scheme.  It panics on any other network.
func (p *plainDNS) Address() string {
	if p.net == networkUDP {
		return p.addr.Host
	}

	if p.net == networkTCP {
		return p.addr.String()
	}

	panic(fmt.Sprintf("unexpected network: %s", p.net))
}

// dialExchange performs a DNS exchange with the specified dial handler.
// network must be either [networkUDP] or [networkTCP].  If the first attempt
// fails with an error accepted by isExpectedConnErr, it dials once more and
// retries.  Every connection it dials is closed before it returns.  A
// successful response is validated against req with validatePlainResponse.
func (p *plainDNS) dialExchange(
	network network,
	dial bootstrap.DialHandler,
	req *dns.Msg,
) (resp *dns.Msg, err error) {
	addr := p.Address()
	client := &dns.Client{Timeout: p.timeout}

	conn := &dns.Conn{}
	if network == networkUDP {
		conn.UDPSize = dns.MinMsgSize
	}

	logBegin(p.logger, addr, network, req)
	defer func() { logFinish(p.logger, addr, network, err) }()

	ctx := context.Background()
	conn.Conn, err = dial(ctx, network, "")
	if err != nil {
		return nil, fmt.Errorf("dialing %s over %s: %w", p.addr.Host, network, err)
	}
	// Bind the current connection as an argument, since conn.Conn may be
	// replaced by the retry below.
	defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn)

	resp, _, err = client.ExchangeWithConn(req, conn)
	if isExpectedConnErr(err) {
		// Retry once over a fresh connection.
		conn.Conn, err = dial(ctx, network, "")
		if err != nil {
			return nil, fmt.Errorf("dialing %s over %s again: %w", p.addr.Host, network, err)
		}
		defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn)

		resp, _, err = client.ExchangeWithConn(req, conn)
	}

	if err != nil {
		// resp is returned as is, since callers check it even on error.
		return resp, fmt.Errorf("exchanging with %s over %s: %w", addr, network, err)
	}

	return resp, validatePlainResponse(req, resp)
}

// isExpectedConnErr reports whether err is worth a second attempt at the
// exchange.  Any [net.Error], including timeouts, and any error wrapping
// [io.EOF] qualify; a nil error doesn't.
func isExpectedConnErr(err error) (is bool) {
	if err == nil {
		return false
	}

	if errors.Is(err, io.EOF) {
		return true
	}

	var netErr net.Error

	return errors.As(err, &netErr)
}

// Exchange implements the [Upstream] interface for *plainDNS.  For UDP
// upstreams, it retries the request over TCP when the UDP response has a
// malformed question section or is truncated.
func (p *plainDNS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	dial, err := p.getDialer()
	if err != nil {
		// The error is informative enough as is.
		return nil, err
	}

	addr := p.Address()

	resp, err = p.dialExchange(p.net, dial, req)

	switch {
	case p.net != networkUDP:
		// Already TCP, nothing to fall back to.
		return resp, err
	case resp == nil:
		// Most probably the upstream itself is failing.
		return resp, err
	case errors.Is(err, errQuestion):
		// The upstream sends malformed messages over UDP; try TCP.
		p.logger.Debug(
			"plain response is malformed, using tcp",
			"addr", addr,
			slogutil.KeyError, err,
		)

		return p.dialExchange(networkTCP, dial, req)
	case resp.Truncated:
		p.logger.Debug(
			"plain response is truncated, using tcp",
			"question", &req.Question[0],
			"addr", addr,
		)

		return p.dialExchange(networkTCP, dial, req)
	default:
		// Either no error or one unrelated to the received message.
		return resp, err
	}
}

// Close implements the [Upstream] interface for *plainDNS.  Connections are
// closed by dialExchange after every exchange, so there is nothing to release.
func (p *plainDNS) Close() (err error) {
	// err is still nil here.
	return err
}

// errQuestion is returned when a message has a malformed question section.
// plainDNS.Exchange uses it to decide whether to retry a UDP request over TCP.
const errQuestion errors.Error = "bad question section"

// validatePlainResponse validates resp from an upstream DNS server for
// compliance with req.  resp must have exactly one question, and it must match
// the first question of req by type and, case-insensitively, by name.  A req
// without questions can't be matched and is reported as an error instead of
// causing a panic.  Any error returned wraps [errQuestion], since this
// essentially validates the question section of resp.
func validatePlainResponse(req, resp *dns.Msg) (err error) {
	if qlen := len(resp.Question); qlen != 1 {
		return fmt.Errorf("%w: only 1 question allowed; got %d", errQuestion, qlen)
	}

	if len(req.Question) == 0 {
		return fmt.Errorf("%w: request has no questions", errQuestion)
	}

	reqQ, respQ := req.Question[0], resp.Question[0]

	if reqQ.Qtype != respQ.Qtype {
		return fmt.Errorf("%w: mismatched type %s", errQuestion, dns.Type(respQ.Qtype))
	}

	// Compare the names case-insensitively, just like CoreDNS does.
	if !strings.EqualFold(reqQ.Name, respQ.Name) {
		return fmt.Errorf("%w: mismatched name %q", errQuestion, respQ.Name)
	}

	return nil
}
070701000000A3000081A4000000000000000000000001679A649F0000124B000000000000000000000000000000000000003000000000dnsproxy-0.75.0/upstream/plain_internal_test.gopackage upstream

import (
	"fmt"
	"io"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestUpstream_plainDNS checks that a plain DNS upstream created from a
// scheme-less address resolves test messages against a local server, several
// times in a row.
func TestUpstream_plainDNS(t *testing.T) {
	srv := startDNSServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		resp := respondToTestMessage(req)

		err := w.WriteMsg(resp)

		pt := testutil.PanicT{}
		require.NoError(pt, err)
	})
	testutil.CleanupAndRequireSuccess(t, srv.Close)

	addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{
		Logger: slogutil.NewDiscardLogger(),
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	for range 10 {
		checkUpstream(t, u, addr)
	}
}

// TestUpstream_plainDNS_badID checks the case of a server that always replies
// with a mismatched message ID.  The exchange must end with a timeout error
// and no response, presumably because the client keeps discarding the
// mismatched replies.
func TestUpstream_plainDNS_badID(t *testing.T) {
	req := createTestMessage()
	badIDResp := respondToTestMessage(req)
	badIDResp.Id++

	srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(badIDResp))
	})
	testutil.CleanupAndRequireSuccess(t, srv.Close)

	addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{
		Logger: slogutil.NewDiscardLogger(),
		// Use a shorter timeout to speed up the test.
		Timeout: 100 * time.Millisecond,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	resp, err := u.Exchange(req)

	var netErr net.Error
	require.ErrorAs(t, err, &netErr)

	assert.True(t, netErr.Timeout())
	assert.Nil(t, resp)
}

// TestUpstream_plainDNS_fallbackToTCP checks that a UDP upstream falls back to
// TCP exactly once when the UDP response is truncated or has a mismatched
// question name or type, and doesn't fall back for a correct response.  The
// server counts requests per network and always answers correctly over TCP.
func TestUpstream_plainDNS_fallbackToTCP(t *testing.T) {
	req := createTestMessage()
	goodResp := respondToTestMessage(req)

	truncResp := goodResp.Copy()
	truncResp.Truncated = true

	badQNameResp := goodResp.Copy()
	badQNameResp.Question[0].Name = "bad." + req.Question[0].Name

	badQTypeResp := goodResp.Copy()
	badQTypeResp.Question[0].Qtype = dns.TypeCNAME

	testCases := []struct {
		udpResp *dns.Msg
		name    string
		wantUDP int
		wantTCP int
	}{{
		udpResp: goodResp,
		name:    "all_right",
		wantUDP: 1,
		wantTCP: 0,
	}, {
		udpResp: truncResp,
		name:    "truncated_response",
		wantUDP: 1,
		wantTCP: 1,
	}, {
		udpResp: badQNameResp,
		name:    "bad_qname",
		wantUDP: 1,
		wantTCP: 1,
	}, {
		udpResp: badQTypeResp,
		name:    "bad_qtype",
		wantUDP: 1,
		wantTCP: 1,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// The counters are updated from the server's goroutines, hence
			// atomics.
			var udpReqNum, tcpReqNum atomic.Uint32
			srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) {
				var resp *dns.Msg
				if w.RemoteAddr().Network() == networkUDP {
					udpReqNum.Add(1)
					resp = tc.udpResp
				} else {
					tcpReqNum.Add(1)
					resp = goodResp
				}

				require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
			})
			testutil.CleanupAndRequireSuccess(t, srv.Close)

			addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
			u, err := AddressToUpstream(addr, &Options{
				Logger: slogutil.NewDiscardLogger(),
				// Use a shorter timeout to speed up the test.
				Timeout: 100 * time.Millisecond,
			})
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			resp, err := u.Exchange(req)
			require.NoError(t, err)
			requireResponse(t, req, resp)

			assert.Equal(t, tc.wantUDP, int(udpReqNum.Load()))
			assert.Equal(t, tc.wantTCP, int(tcpReqNum.Load()))
		})
	}
}

// testDNSServer is a simple DNS server that can be used in unit-tests.  It
// serves the same handler over UDP and TCP on the same port of 127.0.0.1:
// udpListener and tcpListener are the underlying sockets, udpSrv and tcpSrv
// are the servers running on top of them, and port is the port they share.
// Create it with startDNSServer and release it with Close.
type testDNSServer struct {
	udpListener net.PacketConn
	tcpListener net.Listener
	udpSrv      *dns.Server
	tcpSrv      *dns.Server
	port        int
}

// type check
var _ io.Closer = (*testDNSServer)(nil)

// startDNSServer starts a test DNS server that serves handler over both UDP
// and TCP on the same randomly chosen port of 127.0.0.1.  It returns only
// after both servers have reported that they started, so that calling Close
// right away doesn't race with the startup.  The caller is responsible for
// closing the returned server.
func startDNSServer(t *testing.T, handler dns.HandlerFunc) (s *testDNSServer) {
	t.Helper()

	s = &testDNSServer{}

	udpListener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)

	s.port = udpListener.LocalAddr().(*net.UDPAddr).Port
	s.udpListener = udpListener

	s.tcpListener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", s.port))
	if err != nil {
		// Don't leak the UDP socket when the TCP port is unavailable.  The
		// close error is irrelevant here, since the test fails right below.
		_ = udpListener.Close()
	}
	require.NoError(t, err)

	udpStarted := make(chan struct{})
	s.udpSrv = &dns.Server{
		PacketConn:        s.udpListener,
		Handler:           handler,
		NotifyStartedFunc: func() { close(udpStarted) },
	}

	tcpStarted := make(chan struct{})
	s.tcpSrv = &dns.Server{
		Listener:          s.tcpListener,
		Handler:           handler,
		NotifyStartedFunc: func() { close(tcpStarted) },
	}

	go func() {
		pt := testutil.PanicT{}
		require.NoError(pt, s.udpSrv.ActivateAndServe())
	}()

	go func() {
		pt := testutil.PanicT{}
		require.NoError(pt, s.tcpSrv.ActivateAndServe())
	}()

	// Wait for both servers to start, since shutting down a server that
	// hasn't started yet may fail.
	<-udpStarted
	<-tcpStarted

	return s
}

// Close implements the io.Closer interface for *testDNSServer.  It always
// shuts down both the UDP and the TCP servers, in that order, and combines
// their errors using errors.WithDeferred.
func (s *testDNSServer) Close() (err error) {
	return errors.WithDeferred(s.udpSrv.Shutdown(), s.tcpSrv.Shutdown())
}
070701000000A4000081A4000000000000000000000001679A649F000023D5000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/resolver.gopackage upstream

import (
	"context"
	"fmt"
	"math"
	"net/netip"
	"net/url"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/miekg/dns"
)

// Resolver resolves hostnames to IP addresses.  Note that [net.Resolver] from
// the standard library also implements this interface.
type Resolver = bootstrap.Resolver

// StaticResolver is a resolver which always responds with an underlying slice
// of IP addresses.
type StaticResolver = bootstrap.StaticResolver

// ParallelResolver is a slice of resolvers that are queried concurrently until
// the first successful response is returned, as opposed to all resolvers being
// queried in order in [ConsequentResolver].
type ParallelResolver = bootstrap.ParallelResolver

// ConsequentResolver is a slice of resolvers that are queried in order until
// the first successful non-empty response, as opposed to just successful
// response requirement in [ParallelResolver].
type ConsequentResolver = bootstrap.ConsequentResolver

// UpstreamResolver is a wrapper around Upstream that implements the
// [bootstrap.Resolver] interface by sending A and AAAA queries to it.
type UpstreamResolver struct {
	// Upstream is used for lookups.  It must not be nil.
	Upstream
}

// NewUpstreamResolver creates an upstream that can be used as bootstrap
// [Resolver].  resolverAddress format is the same as in the
// [AddressToUpstream].  If the upstream can't be used as a bootstrap, the
// returned error will have the underlying type of [NotBootstrapError], and r
// itself will be fully usable.  Closing r.Upstream is caller's responsibility.
func NewUpstreamResolver(resolverAddress string, opts *Options) (r *UpstreamResolver, err error) {
	upsOpts := &Options{}

	// TODO(ameshkov):  Aren't other options needed here?
	if opts != nil {
		upsOpts.Timeout = opts.Timeout
		upsOpts.VerifyServerCertificate = opts.VerifyServerCertificate
		upsOpts.PreferIPv6 = opts.PreferIPv6
		upsOpts.Logger = opts.Logger
	}

	ups, err := AddressToUpstream(resolverAddress, upsOpts)
	if err != nil {
		err = fmt.Errorf("upstream bootstrap: creating upstream: %w", err)

		return nil, err
	}

	return &UpstreamResolver{Upstream: ups}, validateBootstrap(ups)
}

// NotBootstrapError is returned by [NewUpstreamResolver] when the parsed
// upstream can't be used as a bootstrap, and wraps the actual reason.  It's
// returned by value, so the target for errors.As must be a
// *NotBootstrapError.
type NotBootstrapError struct {
	// err is the actual reason why the upstream can't be used as a bootstrap.
	err error
}

// type check
var _ error = NotBootstrapError{}

// Error implements the [error] interface for NotBootstrapError.
func (e NotBootstrapError) Error() (msg string) {
	return fmt.Sprintf("not a bootstrap: %s", e.err)
}

// type check
var _ errors.Wrapper = NotBootstrapError{}

// Unwrap implements the [errors.Wrapper] interface.  It returns the wrapped
// reason.
func (e NotBootstrapError) Unwrap() (reason error) {
	return e.err
}

// validateBootstrap returns an error if u can't be used as a bootstrap.  That
// happens when u is of an unknown type, or when the host of its address isn't
// an IP address, in which case a [NotBootstrapError] is returned.  DNSCrypt
// upstreams are accepted without checks.
func validateBootstrap(u Upstream) (err error) {
	var addr *url.URL
	switch ups := u.(type) {
	case *dnsCrypt:
		return nil
	case *plainDNS:
		addr = ups.addr
	case *dnsOverTLS:
		addr = ups.addr
	case *dnsOverHTTPS:
		addr = ups.addr
	case *dnsOverQUIC:
		addr = ups.addr
	default:
		return fmt.Errorf("unknown upstream type: %T", u)
	}

	// A hostname would require a bootstrap to resolve it in turn.
	if _, err = netip.ParseAddr(addr.Hostname()); err != nil {
		return NotBootstrapError{err: err}
	}

	return nil
}

// type check
var _ Resolver = (*UpstreamResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *UpstreamResolver.  It
// ignores the TTLs of the received records.  An empty host yields no addresses
// and no error.  host is normalized to a lower-case FQDN before the lookup.
//
// TODO(e.burkov):  Investigate why the empty slice is returned instead of nil.
func (r *UpstreamResolver) LookupNetIP(
	ctx context.Context,
	network bootstrap.Network,
	host string,
) (ips []netip.Addr, err error) {
	if host == "" {
		return nil, nil
	}

	fqdn := dns.Fqdn(strings.ToLower(host))

	res, err := r.lookupNetIP(ctx, network, fqdn)
	if err != nil {
		// Failures produce a non-nil empty slice, see the TODO above.
		return []netip.Addr{}, err
	}

	return res.addrs, nil
}

// ipResult holds the valid addresses from the A/AAAA records of DNS responses
// along with the time after which they are considered stale.  It's used to
// cache the results of lookups.
type ipResult struct {
	expire time.Time
	addrs  []netip.Addr
}

// lookupNetIP performs a DNS lookup of host and returns the result.  network
// must be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6], or
// [bootstrap.NetworkIP].  host must be in a lower-case FQDN form.
//
// For [bootstrap.NetworkIP], both address families are queried concurrently.
// The addresses of the successful lookups are merged, the earliest of their
// expiration times is used, and the errors of the failed ones are joined.
// The non-nil result is returned even when one of the lookups fails.
//
// TODO(e.burkov):  Use context.
func (r *UpstreamResolver) lookupNetIP(
	_ context.Context,
	network bootstrap.Network,
	host string,
) (result *ipResult, err error) {
	switch network {
	case bootstrap.NetworkIP4, bootstrap.NetworkIP6:
		return r.request(host, network)
	case bootstrap.NetworkIP:
		// Query both families below.
	default:
		return nil, fmt.Errorf("unsupported network %s", network)
	}

	families := []bootstrap.Network{bootstrap.NetworkIP4, bootstrap.NetworkIP6}

	resCh := make(chan any, len(families))
	for _, fam := range families {
		go r.resolveAsync(resCh, host, fam)
	}

	result = &ipResult{}

	var errs []error
	for range families {
		switch v := (<-resCh).(type) {
		case error:
			errs = append(errs, v)
		case *ipResult:
			if result.expire.IsZero() || v.expire.Before(result.expire) {
				result.expire = v.expire
			}

			result.addrs = append(result.addrs, v.addrs...)
		}
	}

	return result, errors.Join(errs...)
}

// request performs a single DNS lookup of host and returns all the valid
// addresses from the answer section of the response.  network must be either
// [bootstrap.NetworkIP4], or [bootstrap.NetworkIP6], otherwise it panics.
// host must be in a lower-case FQDN form.
//
// The result expires after the smallest TTL among the records that carried a
// valid address.  If there are no such records, the maximum uint32 number of
// seconds is used instead.
//
// TODO(e.burkov):  Consider NS and Extra sections when setting TTL.  Check out
// what RFCs say about it.
func (r *UpstreamResolver) request(host string, n bootstrap.Network) (res *ipResult, err error) {
	var qtype uint16
	switch n {
	case bootstrap.NetworkIP4:
		qtype = dns.TypeA
	case bootstrap.NetworkIP6:
		qtype = dns.TypeAAAA
	default:
		panic(fmt.Sprintf("unsupported network %q", n))
	}

	req := &dns.Msg{}
	req.Id = dns.Id()
	req.RecursionDesired = true
	req.Question = []dns.Question{{Name: host, Qtype: qtype, Qclass: dns.ClassINET}}

	// As per [Upstream.Exchange] documentation, the response is always returned
	// if no error occurred.
	resp, err := r.Exchange(req)
	if err != nil {
		return nil, err
	}

	received := time.Now()
	addrs := make([]netip.Addr, 0, len(resp.Answer))
	ttl := uint32(math.MaxUint32)
	for _, rr := range resp.Answer {
		if ip := proxyutil.IPFromRR(rr); ip.IsValid() {
			addrs = append(addrs, ip)
			ttl = min(ttl, rr.Header().Ttl)
		}
	}

	return &ipResult{
		expire: received.Add(time.Duration(ttl) * time.Second),
		addrs:  addrs,
	}, nil
}

// resolveAsync performs a single DNS lookup of host for network and sends
// either the resulting *ipResult or the error to resCh, exactly once.  It's
// intended to be run in its own goroutine.
func (r *UpstreamResolver) resolveAsync(resCh chan<- any, host, network string) {
	res, err := r.request(host, network)
	if err != nil {
		resCh <- err

		return
	}

	resCh <- res
}

// CachingResolver is a [Resolver] that caches the results of lookups.  It's
// required to be created with [NewCachingResolver].
type CachingResolver struct {
	// resolver is the underlying resolver to use for lookups.
	resolver *UpstreamResolver

	// mu protects cache and its elements.
	mu *sync.RWMutex

	// cache is the set of resolved hostnames, in lower-case FQDN form, mapped
	// to cached addresses.
	//
	// TODO(e.burkov):  Use expiration cache.
	cache map[string]*ipResult
}

// NewCachingResolver returns a new *CachingResolver that performs lookups
// using r and starts with an empty cache.
func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) {
	cr = &CachingResolver{
		resolver: r,
		mu:       new(sync.RWMutex),
		cache:    make(map[string]*ipResult),
	}

	return cr
}

// type check
var _ Resolver = (*CachingResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *CachingResolver.
//
// It returns a copy of the addresses cached for host if they haven't expired
// yet.  Otherwise, it performs the lookup using the underlying resolver and
// caches the result.
//
// Results without addresses aren't cached.  A single-family lookup with no
// addresses is given an expiration of the maximum uint32 number of seconds,
// and caching it would make the empty answer stick practically forever.
//
// An empty host yields no addresses and no error, as in
// [UpstreamResolver.LookupNetIP].
//
// TODO(e.burkov):  It may appear that several concurrent lookup results rewrite
// each other in the cache.
func (r *CachingResolver) LookupNetIP(
	ctx context.Context,
	network bootstrap.Network,
	host string,
) (addrs []netip.Addr, err error) {
	if host == "" {
		return nil, nil
	}

	now := time.Now()
	host = dns.Fqdn(strings.ToLower(host))

	addrs = r.findCached(host, now)
	if addrs != nil {
		return slices.Clone(addrs), nil
	}

	res, err := r.resolver.lookupNetIP(ctx, network, host)
	if err != nil {
		return []netip.Addr{}, err
	}

	if len(res.addrs) > 0 {
		r.setCached(host, res)
	}

	return slices.Clone(res.addrs), nil
}

// findCached returns the addresses cached for host, or nil if there are none
// or they have already expired at now.  The returned slice is the one stored
// in the cache, so callers should copy it before handing it out.  It's safe
// for concurrent use.
func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.Addr) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if res, ok := r.cache[host]; ok && !res.expire.Before(now) {
		return res.addrs
	}

	return nil
}

// setCached stores res in the address cache for host, replacing any previous
// entry.  It's safe for concurrent use.
func (r *CachingResolver) setCached(host string, res *ipResult) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.cache[host] = res
}
070701000000A5000081A4000000000000000000000001679A649F000009CF000000000000000000000000000000000000003300000000dnsproxy-0.75.0/upstream/resolver_internal_test.gopackage upstream

import (
	"context"
	"net/netip"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestCachingResolver_staleness checks that a cached combined A/AAAA result
// expires together with the record that has the smallest TTL.  The
// "staleness" subtest relies on the cache filled by the "resolve" one.
func TestCachingResolver_staleness(t *testing.T) {
	ip4 := netip.MustParseAddr("1.2.3.4")
	ip6 := netip.MustParseAddr("2001:db8::1")

	const (
		smallTTL = 10 * time.Second
		largeTTL = 1000 * time.Second

		fqdn = "test.fully.qualified.name."
	)

	// onExchange answers A queries with ip4 and smallTTL, and AAAA queries with
	// ip6 and largeTTL.  Any other query type panics.
	onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) {
		resp = (&dns.Msg{}).SetReply(req)

		hdr := dns.RR_Header{
			Name:   req.Question[0].Name,
			Rrtype: req.Question[0].Qtype,
			Class:  dns.ClassINET,
		}
		var rr dns.RR
		switch q := req.Question[0]; q.Qtype {
		case dns.TypeA:
			hdr.Ttl = uint32(smallTTL.Seconds())
			rr = &dns.A{Hdr: hdr, A: ip4.AsSlice()}
		case dns.TypeAAAA:
			hdr.Ttl = uint32(largeTTL.Seconds())
			rr = &dns.AAAA{Hdr: hdr, AAAA: ip6.AsSlice()}
		default:
			require.Contains(testutil.PanicT{}, []uint16{dns.TypeA, dns.TypeAAAA}, q.Qtype)
		}
		resp.Answer = append(resp.Answer, rr)

		return resp, nil
	}

	ups := &dnsproxytest.FakeUpstream{
		OnAddress:  func() (_ string) { panic("not implemented") },
		OnClose:    func() (_ error) { panic("not implemented") },
		OnExchange: onExchange,
	}

	r := NewCachingResolver(&UpstreamResolver{Upstream: ups})

	// This subtest fills the cache that the "staleness" subtest inspects, so
	// the latter is only run if this one succeeds.
	require.True(t, t.Run("resolve", func(t *testing.T) {
		testCases := []struct {
			name    string
			network bootstrap.Network
			want    []netip.Addr
		}{{
			name:    "ip4",
			network: bootstrap.NetworkIP4,
			want:    []netip.Addr{ip4},
		}, {
			name:    "ip6",
			network: bootstrap.NetworkIP6,
			want:    []netip.Addr{ip6},
		}, {
			name:    "both",
			network: bootstrap.NetworkIP,
			want:    []netip.Addr{ip4, ip6},
		}}

		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				if tc.name != "both" {
					t.Skip(`TODO(e.burkov):  Bootstrap now only uses "ip" network, see TODO there.`)
				}

				res, err := r.LookupNetIP(context.Background(), tc.network, fqdn)
				require.NoError(t, err)

				assert.ElementsMatch(t, tc.want, res)
			})
		}
	}))

	t.Run("staleness", func(t *testing.T) {
		// The entry is fresh right after the lookup.
		now := time.Now()
		cached := r.findCached(fqdn, now)
		require.ElementsMatch(t, []netip.Addr{ip4, ip6}, cached)

		// Past the smaller TTL the whole entry is stale, despite the AAAA
		// record still being within its larger TTL.
		cached = r.findCached(fqdn, now.Add(smallTTL+time.Second))
		require.Empty(t, cached)
	})
}
070701000000A6000081A4000000000000000000000001679A649F00000C4C000000000000000000000000000000000000002A00000000dnsproxy-0.75.0/upstream/resolver_test.gopackage upstream_test

import (
	"context"
	"net/netip"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNewUpstreamResolver checks that UpstreamResolver returns the addresses
// received from its upstream.
func TestNewUpstreamResolver(t *testing.T) {
	// onExchange answers any question with a single A record.
	onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) {
		a := &dns.A{
			Hdr: dns.RR_Header{
				Name:   req.Question[0].Name,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    60,
			},
			A: netip.MustParseAddr("1.2.3.4").AsSlice(),
		}

		resp = (&dns.Msg{}).SetReply(req)
		resp.Answer = []dns.RR{a}

		return resp, nil
	}

	r := &upstream.UpstreamResolver{
		Upstream: &dnsproxytest.FakeUpstream{
			OnAddress:  func() (_ string) { panic("not implemented") },
			OnClose:    func() (_ error) { panic("not implemented") },
			OnExchange: onExchange,
		},
	}

	ipAddrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com")
	require.NoError(t, err)

	assert.NotEmpty(t, ipAddrs)
}

func TestNewUpstreamResolver_validity(t *testing.T) {
	t.Parallel()

	withTimeoutOpt := &upstream.Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: 3 * time.Second,
	}

	testCases := []struct {
		name       string
		addr       string
		wantErrMsg string
	}{{
		name:       "udp",
		addr:       "1.1.1.1:53",
		wantErrMsg: "",
	}, {
		name:       "dot",
		addr:       "tls://1.1.1.1",
		wantErrMsg: "",
	}, {
		name:       "doh",
		addr:       "https://1.1.1.1/dns-query",
		wantErrMsg: "",
	}, {
		name:       "sdns",
		addr:       "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
		wantErrMsg: "",
	}, {
		name:       "tcp",
		addr:       "tcp://9.9.9.9",
		wantErrMsg: "",
	}, {
		name: "invalid_tls",
		addr: "tls://dns.adguard.com",
		wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_https",
		addr: "https://dns.adguard.com/dns-query",
		wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_tcp",
		addr: "tcp://dns.adguard.com",
		wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_no_scheme",
		addr: "dns.adguard.com",
		wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			r, err := upstream.NewUpstreamResolver(tc.addr, withTimeoutOpt)
			if tc.wantErrMsg != "" {
				assert.Equal(t, tc.wantErrMsg, err.Error())
				if nberr := (&upstream.NotBootstrapError{}); errors.As(err, &nberr) {
					assert.NotNil(t, r)
				}

				return
			}

			require.NoError(t, err)

			addrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com")
			require.NoError(t, err)

			assert.NotEmpty(t, addrs)
		})
	}
}
070701000000A7000081A4000000000000000000000001679A649F00003215000000000000000000000000000000000000002500000000dnsproxy-0.75.0/upstream/upstream.go// Package upstream implements DNS clients for all known DNS encryption
// protocols.
package upstream

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/netip"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/ameshkov/dnsstamps"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/logging"
)

// Upstream is an interface for a DNS resolver.  All the methods must be safe
// for concurrent use.
type Upstream interface {
	// Exchange sends req to this upstream and returns the response that has
	// been received or an error if something went wrong.  The implementations
	// must not modify req as well as the caller must not modify it until the
	// method returns.  It shouldn't be called after closing.
	Exchange(req *dns.Msg) (resp *dns.Msg, err error)

	// Address returns the human-readable address of the upstream DNS resolver.
	// It may differ from what was passed to [AddressToUpstream].
	Address() (addr string)

	// Closer is used to close the upstream properly.
	io.Closer
}

// QUICTraceFunc is a function that returns a [logging.ConnectionTracer]
// specific for a given role and connection ID.  See [Options.QUICTracer].
type QUICTraceFunc func(
	ctx context.Context,
	role logging.Perspective,
	connID quic.ConnectionID,
) (tracer *logging.ConnectionTracer)

// Options configures the upstreams created by [AddressToUpstream].  The zero
// value is valid.
type Options struct {
	// Logger is used for logging during parsing and upstream exchange.  If nil,
	// [slog.Default] is used.
	Logger *slog.Logger

	// VerifyServerCertificate is used to set the VerifyPeerCertificate property
	// of the *tls.Config for DNS-over-HTTPS, DNS-over-QUIC, and DNS-over-TLS.
	VerifyServerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error

	// VerifyConnection is used to set the VerifyConnection property
	// of the *tls.Config for DNS-over-HTTPS, DNS-over-QUIC, and DNS-over-TLS.
	VerifyConnection func(state tls.ConnectionState) error

	// VerifyDNSCryptCertificate is the callback the DNSCrypt server certificate
	// will be passed to.  It's called in dnsCrypt.exchangeDNSCrypt.
	// Upstream.Exchange method returns any error caused by it.
	VerifyDNSCryptCertificate func(cert *dnscrypt.Cert) error

	// QUICTracer is an optional callback that allows tracing every QUIC
	// connection and logging every packet that goes through.
	QUICTracer QUICTraceFunc

	// RootCAs is the CertPool that must be used by all upstreams.  Redefining
	// RootCAs makes sense on iOS to overcome the 15MB memory limit of the
	// NEPacketTunnelProvider.
	RootCAs *x509.CertPool

	// CipherSuites is a custom list of TLSv1.2 ciphers.
	CipherSuites []uint16

	// Bootstrap is used to resolve upstreams' hostnames.  If nil, the
	// [net.DefaultResolver] will be used.  It isn't used for upstreams whose
	// address is already an IP address with a port.
	Bootstrap Resolver

	// HTTPVersions is a list of HTTP versions that should be supported by the
	// DNS-over-HTTPS client.  If not set, [DefaultHTTPVersions] (HTTP/1.1 and
	// HTTP/2) will be used.
	HTTPVersions []HTTPVersion

	// Timeout is the default upstream timeout.  It's also used as a timeout for
	// bootstrap DNS requests.  Zero value disables the timeout.
	Timeout time.Duration

	// InsecureSkipVerify disables verifying the server's certificate.
	InsecureSkipVerify bool

	// PreferIPv6 tells the bootstrapper to prefer IPv6 addresses for an
	// upstream.
	PreferIPv6 bool
}

// Clone copies o to a new struct.  Note that this is not a deep clone: the
// slices, pointers, functions, and interface values held by o are shared with
// the returned copy.
func (o *Options) Clone() (clone *Options) {
	c := *o

	return &c
}

// HTTPVersion is an enumeration of the HTTP versions that we support.  Values
// that we use in this enumeration are also used as ALPN values.
type HTTPVersion string

const (
	// HTTPVersion11 is HTTP/1.1.
	HTTPVersion11 HTTPVersion = "http/1.1"
	// HTTPVersion2 is HTTP/2.
	HTTPVersion2 HTTPVersion = "h2"
	// HTTPVersion3 is HTTP/3.
	HTTPVersion3 HTTPVersion = "h3"
)

// DefaultHTTPVersions is the list of HTTPVersion that we use by default in
// the DNS-over-HTTPS client, i.e. when [Options.HTTPVersions] is not set.
var DefaultHTTPVersions = []HTTPVersion{HTTPVersion11, HTTPVersion2}

// Default ports used by addPort when the upstream address has none.
const (
	// defaultPortPlain is the default port for plain DNS.
	defaultPortPlain = 53

	// defaultPortDoH is the default port for DNS-over-HTTPS.
	defaultPortDoH = 443

	// defaultPortDoT is the default port for DNS-over-TLS.
	defaultPortDoT = 853

	// defaultPortDoQ is the default port for DNS-over-QUIC.  Before draft
	// version -10, experiments were directed to use ports 8853 and 784.
	//
	// See https://www.rfc-editor.org/rfc/rfc9250.html#name-port-selection.
	defaultPortDoQ = 853
)

// AddressToUpstream converts addr to an Upstream using the specified options.
// addr can be either a URL, or a plain address, either a domain name or an IP.
//
//   - 1.2.3.4 or 1.2.3.4:4321 for plain DNS using IP address;
//   - udp://5.3.5.3:53 or 5.3.5.3:53 for plain DNS using IP address;
//   - udp://name.server:53 or name.server:53 for plain DNS using domain name;
//   - tcp://5.3.5.3:53 for plain DNS-over-TCP using IP address;
//   - tcp://name.server:53 for plain DNS-over-TCP using domain name;
//   - tls://5.3.5.3:853 for DNS-over-TLS using IP address;
//   - tls://name.server:853 for DNS-over-TLS using domain name;
//   - https://5.3.5.3:443/dns-query for DNS-over-HTTPS using IP address;
//   - https://name.server:443/dns-query for DNS-over-HTTPS using domain name;
//   - quic://5.3.5.3:853 for DNS-over-QUIC using IP address;
//   - quic://name.server:853 for DNS-over-QUIC using domain name;
//   - h3://dns.google for DNS-over-HTTPS that only works with HTTP/3;
//   - sdns://... for DNS stamp, see https://dnscrypt.info/stamps-specifications.
//
// If addr doesn't have port specified, the default port of the appropriate
// protocol will be used.
//
// opts are shallow-copied before use, so the defaults filled in here and the
// static bootstrap set for DNS stamps don't leak into the caller's value.  The
// slices and pointers inside opts are still shared, so opts shouldn't be
// modified afterwards.  nil value is valid.
func AddressToUpstream(addr string, opts *Options) (u Upstream, err error) {
	if opts == nil {
		opts = &Options{}
	} else {
		// Work on a copy, since the caller may share opts between several
		// upstreams, and parsing a DNS stamp overwrites Bootstrap.
		opts = opts.Clone()
	}

	if opts.Logger == nil {
		opts.Logger = slog.Default()
	}

	var uu *url.URL
	if strings.Contains(addr, "://") {
		uu, err = url.Parse(addr)
		if err != nil {
			return nil, fmt.Errorf("failed to parse %s: %w", addr, err)
		}
	} else {
		uu = &url.URL{
			Scheme: "udp",
			Host:   addr,
		}
	}

	err = validateUpstreamURL(uu)
	if err != nil {
		// Don't wrap the error, because it's informative enough as is.
		return nil, err
	}

	return urlToUpstream(uu, opts)
}

// validateUpstreamURL returns an error if the upstream URL is not valid.  The
// host of u must be either an IP address, possibly enclosed in square
// brackets, or a valid domain name; if a port is present, it must fit into 16
// bits.  DNS stamps are not validated here.
func validateUpstreamURL(u *url.URL) (err error) {
	if u.Scheme == "sdns" {
		return nil
	}

	host := u.Host
	// TODO(s.chzhen):  Consider using [netutil.SplitHostPort].
	if h, port, splitErr := net.SplitHostPort(host); splitErr == nil {
		if _, err = strconv.ParseUint(port, 10, 16); err != nil {
			return fmt.Errorf("invalid port %s: %w", port, err)
		}

		host = h
	}

	// minEnclosedIPv6Len is the minimum length of an IP address enclosed in
	// square brackets.
	const minEnclosedIPv6Len = len("[::]")

	// The host might be an IPv6 address enclosed in square brackets with no
	// port.
	//
	// See https://github.com/AdguardTeam/dnsproxy/issues/379.
	possibleIP := host
	isEnclosed := len(host) >= minEnclosedIPv6Len &&
		strings.HasPrefix(host, "[") &&
		strings.HasSuffix(host, "]")
	if isEnclosed {
		possibleIP = host[1 : len(host)-1]
	}

	if netutil.IsValidIPString(possibleIP) {
		return nil
	}

	if err = netutil.ValidateDomainName(host); err != nil {
		return fmt.Errorf("invalid address %s: %w", host, err)
	}

	return nil
}

// urlToUpstream converts uu to an Upstream using opts, choosing the
// implementation by the URL scheme.  It returns an error for unknown schemes.
func urlToUpstream(uu *url.URL, opts *Options) (u Upstream, err error) {
	sch := uu.Scheme

	switch sch {
	case "udp", "tcp":
		return newPlain(uu, opts)
	case "tls":
		return newDoT(uu, opts)
	case "https", "h3":
		return newDoH(uu, opts)
	case "quic":
		return newDoQ(uu, opts)
	case "sdns":
		return parseStamp(uu, opts)
	}

	return nil, fmt.Errorf("unsupported url scheme: %s", sch)
}

// parseStamp converts a DNS stamp to an Upstream.  If the stamp contains a
// server address, it must be an IP address, optionally with a port, and it's
// used as the static bootstrap of the resulting upstream.  opts itself is
// never modified.
func parseStamp(upsURL *url.URL, opts *Options) (u Upstream, err error) {
	stamp, err := dnsstamps.NewServerStampFromString(upsURL.String())
	if err != nil {
		return nil, fmt.Errorf("failed to parse %s: %w", upsURL, err)
	}

	// TODO(e.burkov):  Port?
	if stamp.ServerAddrStr != "" {
		host, _, sErr := netutil.SplitHostPort(stamp.ServerAddrStr)
		if sErr != nil {
			host = stamp.ServerAddrStr
		}

		var ip netip.Addr
		ip, err = netip.ParseAddr(host)
		if err != nil {
			return nil, fmt.Errorf("invalid server stamp address %s", stamp.ServerAddrStr)
		}

		// Don't modify the caller's options, since they may be shared with
		// other upstreams that must not use this stamp's address as their
		// bootstrap.
		opts = opts.Clone()
		opts.Bootstrap = StaticResolver{ip}
	}

	switch stamp.Proto {
	case dnsstamps.StampProtoTypePlain:
		return newPlain(&url.URL{Scheme: "udp", Host: stamp.ServerAddrStr}, opts)
	case dnsstamps.StampProtoTypeDNSCrypt:
		return newDNSCrypt(upsURL, opts), nil
	case dnsstamps.StampProtoTypeDoH:
		return newDoH(&url.URL{Scheme: "https", Host: stamp.ProviderName, Path: stamp.Path}, opts)
	case dnsstamps.StampProtoTypeDoQ:
		return newDoQ(&url.URL{Scheme: "quic", Host: stamp.ProviderName, Path: stamp.Path}, opts)
	case dnsstamps.StampProtoTypeTLS:
		return newDoT(&url.URL{Scheme: "tls", Host: stamp.ProviderName}, opts)
	default:
		return nil, fmt.Errorf("unsupported stamp protocol %s", &stamp.Proto)
	}
}

// addPort appends port to the host of u if it has no port yet.  A nil u is
// left as is.
func addPort(u *url.URL, port uint16) {
	if u == nil {
		return
	}

	if _, _, err := net.SplitHostPort(u.Host); err != nil {
		// No port in the host, so add the default one.
		u.Host = netutil.JoinHostPort(u.Host, port)
	}
}

// logBegin logs the start of DNS request resolution at the debug level.  It
// should be called right before dialing the connection to the upstream.  n is
// the [network] that will be used to send the request.  The question type and
// name are left empty if req has no question.
func logBegin(l *slog.Logger, addr string, n network, req *dns.Msg) {
	var (
		qtype dns.Type
		qname string
	)

	if len(req.Question) > 0 {
		q := req.Question[0]
		qtype, qname = dns.Type(q.Qtype), q.Name
	}

	l.Debug("sending request", "addr", addr, "proto", n, "qtype", qtype, "qname", qname)
}

// logFinish logs the end of DNS request resolution.  It should be called right
// after receiving the response from the upstream or the failing action.  n is
// the [network] that was used to send the request.  The record is logged at
// the debug level, except for timeouts, which are logged as errors.
func logFinish(l *slog.Logger, addr string, n network, err error) {
	lvl, status := slog.LevelDebug, "ok"
	if err != nil {
		status = err.Error()

		// Notify user about the timeout.
		if isTimeout(err) {
			lvl = slog.LevelError
		}
	}

	l.Log(context.TODO(), lvl, "response received", "addr", addr, "proto", n, "status", status)
}

// isTimeout returns true if err is a timeout error: context cancellation and
// deadline errors are treated as timeouts, and for other [net.Error] values
// their Timeout method decides.
//
// TODO(e.burkov):  Move to golibs.
func isTimeout(err error) (ok bool) {
	timeoutErrs := []error{
		context.Canceled,
		context.DeadlineExceeded,
		os.ErrDeadlineExceeded,
	}
	for _, target := range timeoutErrs {
		if errors.Is(err, target) {
			return true
		}
	}

	var netErr net.Error
	if errors.As(err, &netErr) {
		return netErr.Timeout()
	}

	return false
}

// DialerInitializer creates a dial handler and returns it, or an error if the
// handler can't be created.
type DialerInitializer func() (handler bootstrap.DialHandler, err error)

// newDialerInitializer creates an initializer of the dialer that will dial the
// addresses resolved from u using opts.
func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer) {
	var l *slog.Logger
	if opts.Logger != nil {
		l = opts.Logger.With(slogutil.KeyPrefix, "bootstrap")
	} else {
		l = slog.Default()
	}

	// TODO(e.burkov):  Add netutil.IsValidIPPortString.
	if _, err := netip.ParseAddrPort(u.Host); err == nil {
		// Don't resolve the address of the server since it's already an IP.
		handler := bootstrap.NewDialContext(opts.Timeout, l, u.Host)

		return func() (h bootstrap.DialHandler, dialerErr error) {
			return handler, nil
		}
	}

	boot := opts.Bootstrap
	if boot == nil {
		// Use the default resolver for bootstrapping.
		boot = net.DefaultResolver
	}

	return func() (h bootstrap.DialHandler, err error) {
		return bootstrap.ResolveDialContext(u, opts.Timeout, boot, opts.PreferIPv6, l)
	}
}
070701000000A8000081A4000000000000000000000001679A649F00004DA1000000000000000000000000000000000000003300000000dnsproxy-0.75.0/upstream/upstream_internal_test.gopackage upstream

import (
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"net/netip"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/ameshkov/dnsstamps"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TODO(ameshkov): Make tests here not depend on external servers.

// TestMain runs the tests of the package through testutil.DiscardLogOutput,
// presumably to silence the output of the standard log package used by
// dnscrypt — TODO confirm.
//
// TODO(d.kolyshev): Remove this after migrating dnscrypt to slog.
func TestMain(m *testing.M) {
	testutil.DiscardLogOutput(m)
}

// TestUpstream_bootstrapTimeout checks that concurrent exchanges through an
// upstream with an unresponsive bootstrap fail in about the configured
// timeout instead of hanging.
func TestUpstream_bootstrapTimeout(t *testing.T) {
	t.Parallel()

	const (
		timeout = 100 * time.Millisecond
		count   = 10
	)

	// A UDP socket that is never read from emulates a faulty bootstrap.
	udpListener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, udpListener.Close)

	rslv, err := NewUpstreamResolver(udpListener.LocalAddr().String(), &Options{
		Logger:  slogutil.NewDiscardLogger(),
		Timeout: timeout,
	})
	require.NoError(t, err)
	// Closing the upstream of the resolver is the caller's responsibility.
	testutil.CleanupAndRequireSuccess(t, rslv.Close)

	// Create an upstream that uses this faulty bootstrap.
	u, err := AddressToUpstream("tls://random-domain-name", &Options{
		Logger:    slogutil.NewDiscardLogger(),
		Bootstrap: NewCachingResolver(rslv),
		Timeout:   timeout,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// result is sent exactly once by each exchanging goroutine.  An empty
	// failMsg means that the exchange behaved as expected.
	type result struct {
		failMsg string
		idx     int
	}

	// The channel has room for every goroutine, so none of them blocks forever
	// if the test stops reading early.  The goroutines don't use t, since they
	// may outlive the test when it fails.
	resCh := make(chan result, count)
	for i := range count {
		go func(idx int) {
			req := createTestMessage()

			start := time.Now()
			_, rErr := u.Exchange(req)
			elapsed := time.Since(start)

			res := result{idx: idx}
			switch {
			case rErr == nil:
				// Must not happen since bootstrap server cannot work.
				res.failMsg = "the upstream must have timed out"
			case elapsed > 3*timeout:
				// The actual elapsed time may be higher than the timeout due
				// to the execution environment, 3 is an arbitrarily chosen
				// multiplier to account for that.
				res.failMsg = fmt.Sprintf(
					"exchange took more time than the configured timeout: %s",
					elapsed,
				)
			}

			resCh <- res
		}(i)
	}

	for range count {
		select {
		case res := <-resCh:
			if res.failMsg != "" {
				t.Fatalf("Aborted from the goroutine %d: %s", res.idx, res.failMsg)
			}

			t.Logf("Got result from %d", res.idx)
		case <-time.After(timeout * 10):
			t.Fatalf("No response in time")
		}
	}
}

// TestUpstreams checks that upstreams of every supported protocol (plain,
// TCP, DoT, DoH, DNSCrypt, DoQ, DoH3 and DNS stamps), with and without a
// bootstrap resolver, answer a test query correctly.
func TestUpstreams(t *testing.T) {
	t.Parallel()

	const upsTimeout = 10 * time.Second

	l := slogutil.NewDiscardLogger()

	googleRslv, err := NewUpstreamResolver("8.8.8.8:53", &Options{
		Logger:  l,
		Timeout: upsTimeout,
	})
	require.NoError(t, err)
	cloudflareRslv, err := NewUpstreamResolver("1.0.0.1:53", &Options{
		Logger:  l,
		Timeout: upsTimeout,
	})
	require.NoError(t, err)

	googleBoot := NewCachingResolver(googleRslv)
	cloudflareBoot := NewCachingResolver(cloudflareRslv)

	// Every entry must be unique; identical entries only repeat network round
	// trips and produce confusing "#01"-suffixed subtest names.
	upstreams := []struct {
		bootstrap Resolver
		address   string
	}{{
		bootstrap: googleBoot,
		address:   "8.8.8.8:53",
	}, {
		bootstrap: nil,
		address:   "1.1.1.1",
	}, {
		bootstrap: cloudflareBoot,
		address:   "1.1.1.1",
	}, {
		bootstrap: nil,
		address:   "tcp://1.1.1.1:53",
	}, {
		bootstrap: nil,
		address:   "94.140.14.14:5353",
	}, {
		bootstrap: nil,
		address:   "tls://1.1.1.1",
	}, {
		bootstrap: nil,
		address:   "tls://9.9.9.9:853",
	}, {
		bootstrap: googleBoot,
		address:   "tls://dns.adguard.com",
	}, {
		bootstrap: googleBoot,
		address:   "tls://dns.adguard.com:853",
	}, {
		bootstrap: nil,
		address:   "tls://one.one.one.one",
	}, {
		bootstrap: googleBoot,
		address:   "https://1dot1dot1dot1.cloudflare-dns.com/dns-query",
	}, {
		bootstrap: nil,
		address:   "https://dns.google/dns-query",
	}, {
		bootstrap: nil,
		address:   "https://doh.opendns.com/dns-query",
	}, {
		// AdGuard DNS (DNSCrypt)
		bootstrap: nil,
		address:   "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
	}, {
		// AdGuard Family (DNSCrypt)
		bootstrap: googleBoot,
		address:   "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNTo1NDQzILgxXdexS27jIKRw3C7Wsao5jMnlhvhdRUXWuMm1AFq6ITIuZG5zY3J5cHQuZmFtaWx5Lm5zMS5hZGd1YXJkLmNvbQ",
	}, {
		// Cloudflare DNS (DNS-over-HTTPS)
		bootstrap: googleBoot,
		address:   "sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk",
	}, {
		// Google (Plain)
		bootstrap: nil,
		address:   "sdns://AAcAAAAAAAAABzguOC44Ljg",
	}, {
		// AdGuard DNS (DNS-over-TLS)
		bootstrap: googleBoot,
		address:   "sdns://AwAAAAAAAAAAAAAPZG5zLmFkZ3VhcmQuY29t",
	}, {
		// AdGuard DNS (DNS-over-QUIC)
		bootstrap: googleBoot,
		address:   "sdns://BAcAAAAAAAAAAAAXZG5zLmFkZ3VhcmQtZG5zLmNvbTo3ODQ",
	}, {
		// Cloudflare DNS (DNS-over-HTTPS)
		bootstrap: nil,
		address:   "https://1.1.1.1/dns-query",
	}, {
		// AdGuard DNS (DNS-over-QUIC)
		bootstrap: googleBoot,
		address:   "quic://dns.adguard-dns.com",
	}, {
		// Google DNS (HTTP3)
		bootstrap: nil,
		address:   "h3://dns.google/dns-query",
	}}

	for _, test := range upstreams {
		t.Run(test.address, func(t *testing.T) {
			t.Parallel()

			u, upsErr := AddressToUpstream(
				test.address,
				&Options{Logger: l, Bootstrap: test.bootstrap, Timeout: upsTimeout},
			)
			require.NoErrorf(t, upsErr, "failed to generate upstream from address %s", test.address)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkUpstream(t, u, test.address)
		})
	}
}

// TestAddressToUpstream checks that AddressToUpstream turns various address
// forms into the expected normalized upstream address, filling in default
// ports where the input omits them.
func TestAddressToUpstream(t *testing.T) {
	bootRslv, err := NewUpstreamResolver("1.1.1.1", nil)
	require.NoError(t, err)

	// bootOpts is shared by the cases whose host name needs bootstrapping.
	bootOpts := &Options{
		Logger:    slogutil.NewDiscardLogger(),
		Bootstrap: NewCachingResolver(bootRslv),
	}

	testCases := []struct {
		opt  *Options
		addr string
		want string
	}{{
		addr: "1.1.1.1",
		opt:  nil,
		want: "1.1.1.1:53",
	}, {
		addr: "1.1.1.1:5353",
		opt:  nil,
		want: "1.1.1.1:5353",
	}, {
		addr: "one:5353",
		opt:  nil,
		want: "one:5353",
	}, {
		addr: "one.one.one.one",
		opt:  nil,
		want: "one.one.one.one:53",
	}, {
		addr: "udp://one.one.one.one",
		opt:  nil,
		want: "one.one.one.one:53",
	}, {
		addr: "tcp://one.one.one.one",
		opt:  bootOpts,
		want: "tcp://one.one.one.one:53",
	}, {
		addr: "tls://one.one.one.one",
		opt:  bootOpts,
		want: "tls://one.one.one.one:853",
	}, {
		addr: "https://one.one.one.one",
		opt:  bootOpts,
		want: "https://one.one.one.one:443",
	}, {
		addr: "h3://one.one.one.one",
		opt:  bootOpts,
		want: "https://one.one.one.one:443",
	}, {
		addr: "::ffff:1.1.1.1",
		opt:  nil,
		want: "[::ffff:1.1.1.1]:53",
	}, {
		addr: "https://[2606:4700:4700::1111]/dns-query",
		opt:  nil,
		want: "https://[2606:4700:4700::1111]:443/dns-query",
	}, {
		addr: "https://[2606:4700:4700::1111]:443/dns-query",
		opt:  nil,
		want: "https://[2606:4700:4700::1111]:443/dns-query",
	}}

	for _, tt := range testCases {
		addr, opts, want := tt.addr, tt.opt, tt.want

		t.Run(addr, func(t *testing.T) {
			ups, upsErr := AddressToUpstream(addr, opts)
			require.NoError(t, upsErr)
			testutil.CleanupAndRequireSuccess(t, ups.Close)

			assert.Equal(t, want, ups.Address())
		})
	}
}

// TestAddressToUpstream_bads checks that malformed addresses, with or without
// a scheme, are rejected with a precise error message.
func TestAddressToUpstream_bads(t *testing.T) {
	badAddrs := []struct {
		addr       string
		wantErrMsg string
	}{{
		addr:       "asdf://1.1.1.1",
		wantErrMsg: "unsupported url scheme: asdf",
	}, {
		addr: "12345.1.1.1:1234567",
		wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
			`value out of range`,
	}, {
		addr: ":1234567",
		wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
			`value out of range`,
	}, {
		addr:       "host:",
		wantErrMsg: `invalid port : strconv.ParseUint: parsing "": invalid syntax`,
	}, {
		addr:       ":53",
		wantErrMsg: `invalid address : bad domain name "": domain name is empty`,
	}, {
		addr: "!!!",
		wantErrMsg: `invalid address !!!: bad domain name "!!!": bad top-level domain name ` +
			`label "!!!": bad top-level domain name label rune '!'`,
	}, {
		addr: "123",
		wantErrMsg: `invalid address 123: bad domain name "123": bad top-level domain name ` +
			`label "123": all octets are numeric`,
	}, {
		addr: "tcp://12345.1.1.1:1234567",
		wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
			`value out of range`,
	}, {
		addr: "tcp://:1234567",
		wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` +
			`value out of range`,
	}, {
		addr:       "tcp://host:",
		wantErrMsg: `invalid port : strconv.ParseUint: parsing "": invalid syntax`,
	}, {
		addr:       "tcp://:53",
		wantErrMsg: `invalid address : bad domain name "": domain name is empty`,
	}, {
		addr: "tcp://!!!",
		wantErrMsg: `invalid address !!!: bad domain name "!!!": bad top-level domain name ` +
			`label "!!!": bad top-level domain name label rune '!'`,
	}, {
		addr: "tcp://123",
		wantErrMsg: `invalid address 123: bad domain name "123": bad top-level domain name ` +
			`label "123": all octets are numeric`,
	}}

	for _, bad := range badAddrs {
		addr, wantMsg := bad.addr, bad.wantErrMsg

		t.Run(addr, func(t *testing.T) {
			_, upsErr := AddressToUpstream(addr, nil)
			testutil.AssertErrorMsg(t, wantMsg, upsErr)
		})
	}
}

// TestUpstreamDoTBootstrap checks that a DoT upstream given by host name can
// be bootstrapped through encrypted resolvers (DoT, DoH, DNSCrypt).  The
// timeout used here is a package-level value declared outside this function.
func TestUpstreamDoTBootstrap(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		address   string
		bootstrap string
	}{{
		address:   "tls://one.one.one.one/",
		bootstrap: "tls://1.1.1.1",
	}, {
		address:   "tls://one.one.one.one/",
		bootstrap: "https://1.1.1.1/dns-query",
	}, {
		address: "tls://one.one.one.one/",
		// Cisco OpenDNS
		bootstrap: "sdns://AQAAAAAAAAAADjIwOC42Ny4yMjAuMjIwILc1EUAgbyJdPivYItf9aR6hwzzI1maNDL4Ev6vKQ_t5GzIuZG5zY3J5cHQtY2VydC5vcGVuZG5zLmNvbQ",
	}}

	for _, tc := range testCases {
		addr, bootAddr := tc.address, tc.bootstrap

		t.Run(addr, func(t *testing.T) {
			bootRslv, rslvErr := NewUpstreamResolver(bootAddr, &Options{
				Logger:  slogutil.NewDiscardLogger(),
				Timeout: timeout,
			})
			require.NoError(t, rslvErr)

			upsOpts := &Options{
				Logger:    slogutil.NewDiscardLogger(),
				Bootstrap: NewCachingResolver(bootRslv),
				Timeout:   timeout,
			}
			ups, upsErr := AddressToUpstream(addr, upsOpts)
			require.NoErrorf(t, upsErr, "failed to generate upstream from address %s", addr)
			testutil.CleanupAndRequireSuccess(t, ups.Close)

			checkUpstream(t, ups, addr)
		})
	}
}

// TestUpstreamsInvalidBootstrap checks DoH and DoT upstreams that are given
// two bootstrap resolvers, only one of which is valid.  The resolvers are
// combined into a ConsequentResolver, so the upstream must still work.  A
// final subtest checks that an unparsable bootstrap address is rejected.
func TestUpstreamsInvalidBootstrap(t *testing.T) {
	t.Parallel()

	upstreams := []struct {
		address   string
		bootstrap []string
	}{{
		address:   "tls://dns.adguard.com",
		bootstrap: []string{"1.1.1.1:555", "8.8.8.8:53"},
	}, {
		address:   "tls://dns.adguard.com:853",
		bootstrap: []string{"1.0.0.1", "8.8.8.8:535"},
	}, {
		address:   "https://1dot1dot1dot1.cloudflare-dns.com/dns-query",
		bootstrap: []string{"8.8.8.1", "1.0.0.1"},
	}, {
		address:   "https://doh.opendns.com:443/dns-query",
		bootstrap: []string{"1.2.3.4:79", "8.8.8.8:53"},
	}, {
		// Cloudflare DNS (DoH)
		address:   "sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk",
		bootstrap: []string{"8.8.8.8:53", "8.8.8.1:53"},
	}, {
		// AdGuard DNS (DNS-over-TLS)
		address:   "sdns://AwAAAAAAAAAAAAAPZG5zLmFkZ3VhcmQuY29t",
		bootstrap: []string{"1.2.3.4:55", "8.8.8.8"},
	}}

	logger := slogutil.NewDiscardLogger()

	for _, tc := range upstreams {
		addr, bootAddrs := tc.address, tc.bootstrap

		t.Run(addr, func(t *testing.T) {
			t.Parallel()

			bootstrap := make(ConsequentResolver, 0, len(bootAddrs))
			for _, bootAddr := range bootAddrs {
				r, rErr := NewUpstreamResolver(bootAddr, &Options{
					Logger:  logger,
					Timeout: timeout,
				})
				require.NoError(t, rErr)

				bootstrap = append(bootstrap, NewCachingResolver(r))
			}

			ups, upsErr := AddressToUpstream(addr, &Options{
				Logger:    logger,
				Bootstrap: bootstrap,
				Timeout:   timeout,
			})
			require.NoErrorf(t, upsErr, "failed to generate upstream from address %s", addr)
			testutil.CleanupAndRequireSuccess(t, ups.Close)

			checkUpstream(t, ups, addr)
		})
	}

	t.Run("bad_bootstrap", func(t *testing.T) {
		// "asdfasdf" is not a usable bootstrap address.
		_, rErr := NewUpstreamResolver("asdfasdf", nil)
		assert.Error(t, rErr)
	})
}

// TestAddressToUpstream_StaticResolver checks that local DoT and DoH test
// servers can be reached both through a StaticResolver pointing at localhost
// and through DNS stamps that carry the server address themselves.
func TestAddressToUpstream_StaticResolver(t *testing.T) {
	t.Parallel()

	handler := func(w dns.ResponseWriter, m *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(m)))
	}
	dotSrv := startDoTServer(t, handler)
	dohSrv := startDoHServer(t, testDoHServerOptions{})
	_, dohPort, err := net.SplitHostPort(dohSrv.addr)
	require.NoError(t, err)

	// badResolver has a nil Upstream, so it presumably panics if it is ever
	// consulted.  The stamp cases use it to make sure the address embedded in
	// the stamp is used instead (see the NotPanics assertion below).
	badResolver := &UpstreamResolver{Upstream: nil}

	dotAddr := netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(dotSrv.port)).String()
	dotStamp := (&dnsstamps.ServerStamp{
		ServerAddrStr: dotAddr,
		Proto:         dnsstamps.StampProtoTypeTLS,
		ProviderName:  dotAddr,
	}).String()
	dohStamp := (&dnsstamps.ServerStamp{
		ServerAddrStr: dohSrv.addr,
		Proto:         dnsstamps.StampProtoTypeDoH,
		ProviderName:  dohSrv.addr,
		Path:          "/dns-query",
	}).String()

	localhostRslv := func() Resolver {
		return StaticResolver{netutil.IPv4Localhost()}
	}

	upstreams := []struct {
		rslv    Resolver
		name    string
		address string
	}{{
		rslv:    localhostRslv(),
		name:    "dot",
		address: fmt.Sprintf("tls://some.dns.server:%d", dotSrv.port),
	}, {
		rslv:    localhostRslv(),
		name:    "doh",
		address: fmt.Sprintf("https://some.dns.server:%s/dns-query", dohPort),
	}, {
		rslv:    badResolver,
		name:    "dot_stamp",
		address: dotStamp,
	}, {
		rslv:    badResolver,
		name:    "doh_stamp",
		address: dohStamp,
	}}

	for _, tc := range upstreams {
		bootstrap, addr := tc.rslv, tc.address

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ups, uErr := AddressToUpstream(addr, &Options{
				Logger:             slogutil.NewDiscardLogger(),
				Bootstrap:          bootstrap,
				Timeout:            timeout,
				InsecureSkipVerify: true,
			})
			require.NoError(t, uErr)
			testutil.CleanupAndRequireSuccess(t, ups.Close)

			assert.NotPanics(t, func() {
				checkUpstream(t, ups, addr)
			})
		})
	}
}

// TestAddPort checks that addPort appends the given port to a URL host that
// has none, bracketing bare IPv6 addresses, and leaves an existing port alone.
func TestAddPort(t *testing.T) {
	testCases := []struct {
		name string
		want string
		host string
		port uint16
	}{{
		name: "empty",
		want: ":0",
		host: "",
		port: 0,
	}, {
		name: "hostname",
		want: "example.org:53",
		host: "example.org",
		port: 53,
	}, {
		name: "ipv4",
		want: "1.2.3.4:1",
		host: "1.2.3.4",
		port: 1,
	}, {
		name: "ipv6",
		want: "[::1]:1",
		host: "::1",
		port: 1,
	}, {
		name: "ipv6_with_brackets",
		want: "[::1]:1",
		host: "[::1]",
		port: 1,
	}, {
		name: "hostname_with_port",
		want: "example.org:54",
		host: "example.org:54",
		port: 53,
	}, {
		name: "ipv4_with_port",
		want: "1.2.3.4:2",
		host: "1.2.3.4:2",
		port: 1,
	}, {
		name: "ipv6_with_brackets_and_port",
		want: "[::1]:2",
		host: "[::1]:2",
		port: 1,
	}}

	for _, tc := range testCases {
		host, port, want := tc.host, tc.port, tc.want

		t.Run(tc.name, func(t *testing.T) {
			target := &url.URL{Host: host}
			addPort(target, port)

			assert.Equal(t, want, target.Host)
		})
	}
}

// checkUpstream exchanges a freshly created test message with u and requires
// a valid reply.  addr is used only to make the failure message readable.
func checkUpstream(t *testing.T, u Upstream, addr string) {
	t.Helper()

	msg := createTestMessage()
	resp, exchErr := u.Exchange(msg)
	require.NoErrorf(t, exchErr, "couldn't talk to upstream %s", addr)

	requireResponse(t, msg, resp)
}

// checkRaceCondition starts several goroutines that each send a series of
// test messages through u.Exchange concurrently, then waits for all of them.
// Exchange results are discarded; the function is meant to surface data races
// when tests run under the race detector.
func checkRaceCondition(u Upstream) {
	const (
		// goroutinesCount is the overall number of goroutines to run.
		goroutinesCount = 3

		// reqCount is the number of requests each goroutine sends.
		reqCount = 10
	)

	var wg sync.WaitGroup
	for range goroutinesCount {
		wg.Add(1)

		go func() {
			defer wg.Done()

			for range reqCount {
				// Ignore exchange errors here, the point is to check for races.
				_, _ = u.Exchange(createTestMessage())
			}
		}()
	}

	wg.Wait()
}

// createTestMessage returns a recursive A query for
// google-public-dns-a.google.com, the message that requireResponse and
// respondToTestMessage are built around.
func createTestMessage() (m *dns.Msg) {
	m = createHostTestMessage("google-public-dns-a.google.com")

	return m
}

// respondToTestMessage builds a reply to m that carries a single A record,
// google-public-dns-a.google.com. -> 8.8.8.8 with a TTL of 100, which is what
// requireResponse expects.
func respondToTestMessage(m *dns.Msg) (resp *dns.Msg) {
	answer := &dns.A{
		Hdr: dns.RR_Header{
			Name:   "google-public-dns-a.google.com.",
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
			Ttl:    100,
		},
		A: net.IPv4(8, 8, 8, 8),
	}

	resp = new(dns.Msg)
	resp.SetReply(m)
	resp.Answer = append(resp.Answer, answer)

	return resp
}

// createHostTestMessage returns a recursion-desired IN A query for host, which
// is converted to a fully-qualified name, with a freshly generated message ID.
func createHostTestMessage(host string) (req *dns.Msg) {
	question := dns.Question{
		Name:   dns.Fqdn(host),
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	req = &dns.Msg{}
	req.Id = dns.Id()
	req.RecursionDesired = true
	req.Question = []dns.Question{question}

	return req
}

// requireResponse requires reply to be a non-nil answer to req: it must have
// req's ID and contain exactly one A record pointing at 8.8.8.8, matching what
// respondToTestMessage produces for a createTestMessage query.
func requireResponse(t require.TestingT, req, reply *dns.Msg) {
	require.NotNil(t, reply)

	answers := reply.Answer
	require.Lenf(t, answers, 1, "wrong number of answers: %d", len(answers))
	require.Equal(t, req.Id, reply.Id)

	rr := answers[0]
	a, ok := rr.(*dns.A)
	require.Truef(t, ok, "wrong answer type: %v", rr)

	wantIP := net.IPv4(8, 8, 8, 8)
	require.Equalf(t, wantIP, a.A.To16(), "wrong answer: %v", a.A)
}

// createServerTLSConfig creates a self-signed, CA-capable RSA-2048 certificate
// for tlsServerName, valid for five years from now.  The name goes into the
// IP SANs if it parses as an IP address and into the DNS SANs otherwise.  It
// returns a *tls.Config that can be used for both the server and the client,
// along with a pool containing the generated root certificate.
//
// TODO(ameshkov): start using rootCAs in tests instead of InsecureVerify.
func createServerTLSConfig(
	tb testing.TB,
	tlsServerName string,
) (tlsConfig *tls.Config, rootCAs *x509.CertPool) {
	tb.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(tb, err)

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	require.NoError(tb, err)

	notBefore := time.Now()
	notAfter := notBefore.Add(5 * 365 * 24 * time.Hour)

	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"AdGuard Tests"},
		},
		NotBefore: notBefore,
		NotAfter:  notAfter,

		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	ipAddress := net.ParseIP(tlsServerName)
	if ipAddress != nil {
		template.IPAddresses = append(template.IPAddresses, ipAddress)
	} else {
		template.DNSNames = append(template.DNSNames, tlsServerName)
	}

	// Self-signed: the template is both the certificate and its parent.
	derBytes, err := x509.CreateCertificate(
		rand.Reader,
		&template,
		&template,
		publicKey(privateKey),
		privateKey,
	)
	require.NoError(tb, err)

	certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPem := pem.EncodeToMemory(
		&pem.Block{
			Type:  "RSA PRIVATE KEY",
			Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
		},
	)

	cert, err := tls.X509KeyPair(certPem, keyPem)
	require.NoError(tb, err)

	rootCAs = x509.NewCertPool()
	// AppendCertsFromPEM reports failure only through its result, so check it
	// instead of silently returning an empty pool.
	require.True(tb, rootCAs.AppendCertsFromPEM(certPem), "appending root certificate")

	tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{cert},
		ServerName:   tlsServerName,
		RootCAs:      rootCAs,
		MinVersion:   tls.VersionTLS12,
	}

	return tlsConfig, rootCAs
}

// publicKey returns a pointer to the public half of priv when priv is an
// *rsa.PrivateKey or an *ecdsa.PrivateKey, and nil for any other value.
func publicKey(priv any) (pub any) {
	if rsaKey, ok := priv.(*rsa.PrivateKey); ok {
		return &rsaKey.PublicKey
	}

	if ecKey, ok := priv.(*ecdsa.PrivateKey); ok {
		return &ecKey.PublicKey
	}

	return nil
}
07070100000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000B00000000TRAILER!!!1347 blocks
openSUSE Build Service is sponsored by