File dnsproxy-0.55.0.obscpio of Package dnsproxy

07070100000000000081A4000000000000000000000001650C592100000080000000000000000000000000000000000000001D00000000dnsproxy-0.55.0/.codecov.ymlcoverage:
  # Commit statuses reported by Codecov.
  status:
    # Project-wide coverage status with a 40% target and no threshold.
    project:
      default:
        target: '40%'
        threshold: null
    # The per-patch and per-change statuses are turned off.
    patch: false
    changes: false
07070100000001000081A4000000000000000000000001650C592100000049000000000000000000000000000000000000001E00000000dnsproxy-0.55.0/.dockerignore# Ignore everything except for explicitly allowed stuff.
*
!build/docker
07070100000002000081A4000000000000000000000001650C592100000011000000000000000000000000000000000000001F00000000dnsproxy-0.55.0/.gitattributesvendor/** binary
07070100000003000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001800000000dnsproxy-0.55.0/.github07070100000004000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000002200000000dnsproxy-0.55.0/.github/workflows07070100000005000081A4000000000000000000000001650C592100000A5D000000000000000000000000000000000000002D00000000dnsproxy-0.55.0/.github/workflows/build.yamlname: Build

# Go toolchain used by every job; the same version is pinned in docker.yml and
# lint.yaml.
'env':
  'GO_VERSION': '1.20.8'

# Run for tags starting with "v", for branch pushes, and for pull requests.
#
# NOTE(review): per GitHub's filter-pattern rules '*' does not match branch
# names containing '/'; use '**' if such branches should be built on push.
'on':
  'push':
    'tags': ['v*']
    'branches': ['*']
  'pull_request':

'jobs':
  # Runs the test suite on every supported OS and uploads coverage from the
  # Linux run only.
  'tests':
    'runs-on': '${{ matrix.os }}'
    'strategy':
      'matrix':
        'os':
          - 'windows-latest'
          - 'macos-latest'
          - 'ubuntu-latest'
    'steps':
      # Pin actions to release tags instead of a moving branch, and use the
      # same major versions as the other workflows.
      - 'uses': 'actions/checkout@v3'
      - 'uses': 'actions/setup-go@v3'
        'with':
          'go-version': '${{ env.GO_VERSION }}'
      - 'name': 'Run tests'
        'env':
          'CI': '1'
        'run': |-
          make test
      - 'name': 'Upload coverage'
        # codecov-action v1 wraps the deprecated bash uploader; v3 uses the
        # current uploader and takes the report paths via `files`.
        'uses': 'codecov/codecov-action@v3'
        'if': "success() && matrix.os == 'ubuntu-latest'"
        'with':
          'token': '${{ secrets.CODECOV_TOKEN }}'
          'files': './coverage.txt'

  # Builds the release archives and, for "v*" tags, publishes them as a
  # GitHub release.
  'build':
    'needs':
      - 'tests'
    'runs-on': 'ubuntu-latest'
    'steps':
      - 'uses': 'actions/checkout@v3'
      - 'uses': 'actions/setup-go@v3'
        'with':
          'go-version': '${{ env.GO_VERSION }}'
      - 'name': 'Build release'
        'run': |-
          set -e -u -x

          # Tag builds use the tag name as the version; all others are "dev".
          RELEASE_VERSION="${GITHUB_REF##*/}"
          if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi

          # GITHUB_ENV entries are plain NAME=value lines and are not
          # unquoted, so quotes around the value would become part of it.
          echo "RELEASE_VERSION=${RELEASE_VERSION}" >> "$GITHUB_ENV"

          make VERBOSE=1 VERSION="${RELEASE_VERSION}" release

          ls -l build/dnsproxy-*
      # NOTE(review): actions/create-release is archived upstream; consider
      # moving to a maintained release action.
      - 'name': 'Create release'
        'if': "startsWith(github.ref, 'refs/tags/v')"
        'id': 'create_release'
        'uses': 'actions/create-release@v1'
        'env':
          'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}'
        'with':
          'tag_name': '${{ github.ref }}'
          'release_name': 'Release ${{ github.ref }}'
          'draft': false
          'prerelease': false
      - 'name': 'Upload'
        'if': "startsWith(github.ref, 'refs/tags/v')"
        'uses': 'xresloader/upload-to-github-release@v1'
        'env':
          'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}'
        'with':
          'file': 'build/dnsproxy-*.tar.gz;build/dnsproxy-*.zip'
          'tags': true
          'draft': false

  'notify':
    'needs':
      - 'build'
    # Secrets are not passed to workflows that are triggered by a pull request
    # from a fork.  The owner check also keeps pushes to forks from trying to
    # post to Slack, matching the guard used in lint.yaml.
    #
    # Use always() to signal to the runner that this job must run even if the
    # previous ones failed.
    'if':
      ${{
      always() &&
      github.repository_owner == 'AdguardTeam' &&
      (
      github.event_name == 'push' ||
      github.event.pull_request.head.repo.full_name == github.repository
      )
      }}
    'runs-on': 'ubuntu-latest'
    'steps':
      - 'name': 'Conclusion'
        'uses': 'technote-space/workflow-conclusion-action@v1'
      - 'name': 'Send Slack notif'
        'uses': '8398a7/action-slack@v3'
        'with':
          'status': '${{ env.WORKFLOW_CONCLUSION }}'
          'fields': 'workflow, repo, message, commit, author, eventName, ref'
        'env':
          'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}'
          'SLACK_WEBHOOK_URL': '${{ secrets.SLACK_WEBHOOK_URL }}'
07070100000006000081A4000000000000000000000001650C59210000090E000000000000000000000000000000000000002D00000000dnsproxy-0.55.0/.github/workflows/docker.yml'name': Docker

# Go toolchain version; the same value is pinned in build.yaml and lint.yaml.
'env':
  'GO_VERSION': '1.20.8'

# Images are published only on push: for release tags starting with "v", and
# for the master branch, whose builds are pushed with the `dev` tag.
'on':
  'push':
    'tags': ['v*']
    'branches': ['master']

'jobs':
  # Builds the multi-arch image and pushes it to Docker Hub as
  # adguard/dnsproxy.
  'docker':
    'runs-on': 'ubuntu-latest'
    'steps':
      - 'name': 'Checkout'
        'uses': 'actions/checkout@v3'
        'with':
          'fetch-depth': 0
      - 'name': 'Set up Go'
        'uses': 'actions/setup-go@v3'
        'with':
          'go-version': '${{ env.GO_VERSION }}'
      # actions/cache@v2 and the v1 docker setup actions run on the
      # deprecated Node 12 runtime.
      - 'name': 'Set up Go modules cache'
        'uses': 'actions/cache@v3'
        'with':
          'path': '~/go/pkg/mod'
          'key': "${{ runner.os }}-go-${{ hashFiles('go.sum') }}"
          'restore-keys': '${{ runner.os }}-go-'
      - 'name': 'Set up QEMU'
        'uses': 'docker/setup-qemu-action@v2'
      - 'name': 'Set up Docker Buildx'
        'uses': 'docker/setup-buildx-action@v2'
      - 'name': 'Publish to Docker Hub'
        'env':
          'DOCKER_USER': '${{ secrets.DOCKER_USER }}'
          'DOCKER_PASSWORD': '${{ secrets.DOCKER_PASSWORD }}'
        'run': |-
          set -e -u -x

          # Tag builds use the tag name as the version; all others are "dev".
          RELEASE_VERSION="${GITHUB_REF##*/}"
          if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi

          # GITHUB_ENV entries are plain NAME=value lines, so don't quote the
          # value.
          echo "RELEASE_VERSION=${RELEASE_VERSION}" >> "$GITHUB_ENV"

          # Pass the password through stdin instead of argv, and disable
          # tracing so `set -x` doesn't print the credentials.
          set +x
          printf '%s' "${DOCKER_PASSWORD}" | docker login \
            -u="${DOCKER_USER}" \
            --password-stdin
          set -x

          make \
            VERSION="${RELEASE_VERSION}" \
            DOCKER_IMAGE_NAME="adguard/dnsproxy" \
            DOCKER_OUTPUT="type=image,name=adguard/dnsproxy,push=true" \
            VERBOSE="1" \
            docker

  'notify':
    'needs':
      - 'docker'
    # Only notify from the upstream repository, as lint.yaml does, so pushes
    # to forks don't try to post to Slack.
    #
    # Use always() to signal to the runner that this job must run even if the
    # previous ones failed.
    'if':
      ${{
      always() &&
      github.repository_owner == 'AdguardTeam' &&
      (
      github.event_name == 'push' ||
      github.event.pull_request.head.repo.full_name == github.repository
      )
      }}
    'runs-on': 'ubuntu-latest'
    'steps':
      - 'name': 'Conclusion'
        'uses': 'technote-space/workflow-conclusion-action@v1'
      - 'name': 'Send Slack notif'
        'uses': '8398a7/action-slack@v3'
        'with':
          'status': '${{ env.WORKFLOW_CONCLUSION }}'
          'fields': 'workflow, repo, message, commit, author, eventName, ref'
        'env':
          'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}'
          'SLACK_WEBHOOK_URL': '${{ secrets.SLACK_WEBHOOK_URL }}'
07070100000007000081A4000000000000000000000001650C592100000593000000000000000000000000000000000000002C00000000dnsproxy-0.55.0/.github/workflows/lint.yaml'name': 'lint'

# Go toolchain version; the same value is pinned in build.yaml and docker.yml.
'env':
  'GO_VERSION': '1.20.8'

# Lint tags starting with "v", branch pushes, and pull requests.
'on':
  'push':
    'tags': ['v*']
    'branches': ['*']
  'pull_request':

'jobs':
  # Installs the project's Go dependencies and tools, then runs the linters.
  'go-lint':
    'runs-on': 'ubuntu-latest'
    'steps':
      # checkout@v2 runs on the deprecated Node 12 runtime; v3 matches the
      # version used by docker.yml.
      - 'uses': 'actions/checkout@v3'
      - 'name': 'Set up Go'
        'uses': 'actions/setup-go@v3'
        'with':
          'go-version': '${{ env.GO_VERSION }}'
      - 'name': 'run-lint'
        'run': >
          make go-deps go-tools go-lint

  'notify':
    # Reports the overall result of this workflow to Slack once go-lint has
    # finished.
    'needs':
      - 'go-lint'
    'runs-on': 'ubuntu-latest'
    # always() makes this job run even when go-lint failed.  The remaining
    # conditions restrict it to the upstream repository and to runs that
    # receive secrets: workflows triggered by a pull request from a fork do
    # not get them.
    'if':
      ${{
        always() &&
        github.repository_owner == 'AdguardTeam' &&
        (
          github.event_name == 'push' ||
          github.event.pull_request.head.repo.full_name == github.repository
        )
      }}
    'steps':
      - 'name': 'Conclusion'
        'uses': 'technote-space/workflow-conclusion-action@v1'
      - 'name': 'Send Slack notif'
        'uses': '8398a7/action-slack@v3'
        'with':
          'status': "${{ env.WORKFLOW_CONCLUSION }}"
          'fields': "workflow, repo, message, commit, author, eventName, ref"
        'env':
          'GITHUB_TOKEN': "${{ secrets.GITHUB_TOKEN }}"
          'SLACK_WEBHOOK_URL': "${{ secrets.SLACK_WEBHOOK_URL }}"
07070100000008000081A4000000000000000000000001650C5921000001D3000000000000000000000000000000000000001B00000000dnsproxy-0.55.0/.gitignore# Please, DO NOT put your text editors' temporary files here.  The more are
# added, the harder it gets to maintain and manage projects' gitignores.  Put
# them into your global gitignore file instead.
#
# See https://stackoverflow.com/a/7335487/1892060.
#
# Only build, run, and test outputs here.  Sorted.  With negations at the
# bottom to make sure they take effect.
*.out
*.test
/bin/
build
dnsproxy
dnsproxy.exe
example.crt
example.key
coverage.txt
config.yaml
07070100000009000081A4000000000000000000000001650C592100000605000000000000000000000000000000000000001E00000000dnsproxy-0.55.0/.golangci.yml# options for analysis running
run:
  # Number of CPUs used for the analysis.  By default golangci-lint uses all
  # available CPUs.
  concurrency: 4

  # Timeout for the analysis, e.g. 30s or 5m; the default is 1m.  This option
  # replaces the deprecated `deadline`.
  timeout: 2m

  # Files whose issues are not reported (they are still analyzed).
  # Autogenerated files are recognized automatically, so there is no need to
  # list every one of them here.
  skip-files:
    - ".*generated.*"

# all available settings of specific linters
linters-settings:
  gocyclo:
    # Cyclomatic-complexity threshold for gocyclo reports.
    min-complexity: 20
  lll:
    # NOTE(review): lll is not in the `linters.enable` list, so this setting
    # currently has no effect.
    line-length: 200

linters:
  # Run exactly the linters listed under `enable`, nothing else.
  disable-all: true
  # Don't set `fast: true` here: it keeps only the "fast" linters from this
  # list and silently skips the ones that need type information, such as
  # errcheck, govet, staticcheck, and unused.
  enable:
    - dupl
    - errcheck
    - gocyclo
    - goimports
    - gosec
    - govet
    - ineffassign
    - misspell
    - revive
    - staticcheck
    - stylecheck
    - unconvert
    - unused

issues:
  # Turn off golangci-lint's built-in default exclude patterns so that only
  # the patterns in `exclude` below suppress issues.  Run
  # `golangci-lint run --help` to list the built-in patterns.
  exclude-use-default: false

  # Regular expressions matched against issue texts; matching issues are not
  # reported.
  exclude:
    # gosec: Expect file permissions to be 0600 or less.
    - 'G302'
    # gosec: false positive triggered by 'src, err := os.ReadFile(filename)'.
    - 'Potential file inclusion via variable'
    # gosec: TLS InsecureSkipVerify may be true.  A configuration option
    # deliberately allows this.
    - 'G402'
    # gosec: Use of weak random number generator.
    - 'G404'
0707010000000A000081A4000000000000000000000001650C592100002C57000000000000000000000000000000000000001800000000dnsproxy-0.55.0/LICENSE
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2020 Adguard Software Ltd

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
0707010000000B000081A4000000000000000000000001650C592100000916000000000000000000000000000000000000001900000000dnsproxy-0.55.0/Makefile# Keep the Makefile POSIX-compliant.  We currently allow hyphens in
# target names, but that may change in the future.
#
# See https://pubs.opengroup.org/onlinepubs/9699919799/utilities/make.html.
.POSIX:

# Don't name this macro "GO", because GNU Make apparently makes it an
# exported environment variable with the literal value of "${GO:-go}",
# which is not what we need.  Use a dot in the name to make sure that
# users don't have an environment variable with the same name.
#
# See https://unix.stackexchange.com/q/646255/105635.
GO.MACRO = $${GO:-go}
VERBOSE.MACRO = $${VERBOSE:-0}

# Build parameters passed to the scripts through $(ENV).  Any of them can be
# overridden on the make command line, e.g. `make VERSION=v1.2.3 release`.
#
# BRANCH and REVISION use `$$`, so the `$( git ... )` command substitution is
# left for the shell running each recipe to evaluate, not make.
BRANCH = $$( git rev-parse --abbrev-ref HEAD )
GOAMD64 = v1
# Module proxies separated by `|`; per the Go documentation, `|` makes the go
# command fall back to the next entry on any error.
GOPROXY = https://goproxy.cn|https://proxy.golang.org|direct
DIST_DIR = build
OUT = dnsproxy
# Overridden to '1' by the go-test target.
RACE = 0
REVISION = $$( git rev-parse --short HEAD )
VERSION = 0

# Environment prefix for the recipes below, which mostly run scripts from
# ./scripts/make.  PATH is extended with ./bin and "$(go env GOPATH)/bin" so
# that tools installed there are found.
#
# NOTE(review): COMMIT is not defined in this Makefile, so it expands to an
# empty string unless set on the command line; REVISION holds the short
# commit hash.  Confirm which one the scripts actually read.
ENV = env\
	BRANCH="$(BRANCH)"\
	COMMIT='$(COMMIT)'\
	DIST_DIR='$(DIST_DIR)'\
	GO="$(GO.MACRO)"\
	GOAMD64='$(GOAMD64)'\
	GOPROXY='$(GOPROXY)'\
	OUT='$(OUT)'\
	PATH="$${PWD}/bin:$$( "$(GO.MACRO)" env GOPATH )/bin:$${PATH}"\
	RACE='$(RACE)'\
	REVISION="$(REVISION)"\
	VERBOSE="$(VERBOSE.MACRO)"\
	VERSION="$(VERSION)"\

# Keep the line above blank: the trailing backslash after VERSION continues
# the macro onto it.

# Keep this target first, so that a naked make invocation triggers a
# full build.
build: go-deps go-build

# Make git use the hooks stored in ./scripts/hooks.
init: ; git config core.hooksPath ./scripts/hooks

test: go-test

# Thin wrappers around the scripts in ./scripts/make.  go-test always runs
# with RACE='1', overriding the RACE macro.
go-build: ; $(ENV)          "$(SHELL)" ./scripts/make/go-build.sh
go-deps:  ; $(ENV)          "$(SHELL)" ./scripts/make/go-deps.sh
go-lint:  ; $(ENV)          "$(SHELL)" ./scripts/make/go-lint.sh
go-test:  ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh
go-tools: ; $(ENV)          "$(SHELL)" ./scripts/make/go-tools.sh

# Install tools, lint, and test in one go.
go-check: go-tools go-lint go-test

# A quick check to make sure that all operating systems relevant to the
# development of the project can be typechecked and built successfully.
#
# Each line runs `go vet` over the whole module for one GOOS.  Note that this
# target doesn't use $(ENV), so the GOPROXY and PATH settings from there don't
# apply.
go-os-check:
	env GOOS='darwin'  "$(GO.MACRO)" vet ./...
	env GOOS='freebsd' "$(GO.MACRO)" vet ./...
	env GOOS='linux'   "$(GO.MACRO)" vet ./...
	env GOOS='openbsd' "$(GO.MACRO)" vet ./...
	env GOOS='windows' "$(GO.MACRO)" vet ./...

txt-lint: ; $(ENV) "$(SHELL)" ./scripts/make/txt-lint.sh

# Quote the Go command as every other recipe does, so that a GO value that
# contains spaces (e.g. a path) still works.
clean: ; $(ENV) "$(GO.MACRO)" clean && rm -f -r '$(DIST_DIR)'

# Build the release artifacts into $(DIST_DIR), starting from a clean tree.
release: clean
	$(ENV) "$(SHELL)" ./scripts/make/build-release.sh

# Build the Docker image after a release build; the DOCKER_* variables passed
# on the command line (see docker.yml) control naming and pushing.
docker: release
	$(ENV) "$(SHELL)" ./scripts/make/build-docker.sh
0707010000000C000081A4000000000000000000000001650C592100003DC0000000000000000000000000000000000000001A00000000dnsproxy-0.55.0/README.md[![Code Coverage](https://img.shields.io/codecov/c/github/AdguardTeam/dnsproxy/master.svg)](https://codecov.io/github/AdguardTeam/dnsproxy?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/AdguardTeam/dnsproxy)](https://goreportcard.com/report/github.com/AdguardTeam/dnsproxy)
[![GolangCI](https://golangci.com/badges/github.com/AdguardTeam/dnsproxy.svg)](https://golangci.com/r/github.com/AdguardTeam/dnsproxy)
[![Go Doc](https://pkg.go.dev/badge/github.com/AdguardTeam/dnsproxy.svg)](https://pkg.go.dev/github.com/AdguardTeam/dnsproxy)

# DNS Proxy <!-- omit in toc -->

A simple DNS proxy server that supports all existing DNS protocols including
`DNS-over-TLS`, `DNS-over-HTTPS`, `DNSCrypt`, and `DNS-over-QUIC`. Moreover,
it can work as a `DNS-over-HTTPS`, `DNS-over-TLS` or `DNS-over-QUIC` server.

- [How to install](#how-to-install)
- [How to build](#how-to-build)
- [Usage](#usage)
- [Examples](#examples)
  - [Simple options](#simple-options)
  - [Encrypted upstreams](#encrypted-upstreams)
  - [Encrypted DNS server](#encrypted-dns-server)
  - [Additional features](#additional-features)
  - [DNS64 server](#dns64-server)
  - [Fastest addr + cache-min-ttl](#fastest-addr--cache-min-ttl)
  - [Specifying upstreams for domains](#specifying-upstreams-for-domains)
  - [EDNS Client Subnet](#edns-client-subnet)
  - [Bogus NXDomain](#bogus-nxdomain)

## How to install

There are several ways to install `dnsproxy`:

1. Grab the binary for your device/OS from the [Releases][releases] page.
2. Use the [official Docker image][docker].
3. Build it yourself (see the instructions below).

[releases]: https://github.com/AdguardTeam/dnsproxy/releases
[docker]: https://hub.docker.com/r/adguard/dnsproxy

## How to build

You will need Go v1.20 or later.

```shell
$ make build
```

## Usage

```
Usage:
  dnsproxy [OPTIONS]

Application Options:
      --config-path=           yaml configuration file. Minimal working configuration in config.yaml.dist.
                               Options passed through command line will override the ones from this file.
  -v, --verbose                Verbose output (optional)
  -o, --output=                Path to the log file. If not set, write to stdout.
  -l, --listen=                Listening addresses
  -p, --port=                  Listening ports. Zero value disables TCP and UDP listeners
  -s, --https-port=            Listening ports for DNS-over-HTTPS
  -t, --tls-port=              Listening ports for DNS-over-TLS
  -q, --quic-port=             Listening ports for DNS-over-QUIC
  -y, --dnscrypt-port=         Listening ports for DNSCrypt
  -c, --tls-crt=               Path to a file with the certificate chain
  -k, --tls-key=               Path to a file with the private key
      --tls-min-version=       Minimum TLS version, for example 1.0
      --tls-max-version=       Maximum TLS version, for example 1.3
      --insecure               Disable secure TLS certificate validation
  -g, --dnscrypt-config=       Path to a file with DNSCrypt configuration. You can generate one using https://github.com/ameshkov/dnscrypt
      --http3                  Enable HTTP/3 support
  -u, --upstream=              An upstream to be used (can be specified multiple times).
                               You can also specify path to a file with the list of servers
  -b, --bootstrap=             Bootstrap DNS for DoH and DoT, can be specified multiple times (default: 8.8.8.8:53)
  -f, --fallback=              Fallback resolvers to use when regular ones are unavailable, can be specified multiple times.
                               You can also specify path to a file with the list of servers
      --private-rdns-upstream= Private DNS upstreams to use for reverse DNS lookups of private addresses, can
                               be specified multiple times
      --all-servers            If specified, parallel queries to all configured upstream servers are enabled
      --fastest-addr           Respond to A or AAAA requests only with the fastest IP address
      --timeout=               Timeout for outbound DNS queries to remote upstream servers in a
                               human-readable form (default: 10s)
      --cache                  If specified, DNS cache is enabled
      --cache-size=            Cache size (in bytes). Default: 64k
      --cache-min-ttl=         Minimum TTL value for DNS entries, in seconds. Capped at 3600.
                               Artificially extending TTLs should only be done with careful consideration.
      --cache-max-ttl=         Maximum TTL value for DNS entries, in seconds.
      --cache-optimistic       If specified, optimistic DNS cache is enabled
  -r, --ratelimit=             Ratelimit (requests per second)
      --refuse-any             If specified, refuse ANY requests
      --edns                   Use EDNS Client Subnet extension
      --edns-addr=             Send EDNS Client Address
      --dns64                  If specified, dnsproxy will act as a DNS64 server
      --dns64-prefix=          Prefix used to handle DNS64. If not specified, dnsproxy uses the 'Well-Known Prefix' 64:ff9b::.
                               Can be specified multiple times
      --https-server-name=     Set the Server header for the responses from the HTTPS server. (default: dnsproxy)
      --ipv6-disabled          If specified, all AAAA requests will be replied with NoError RCode and empty answer
      --bogus-nxdomain=        Transform the responses containing at least a single IP that matches specified addresses
                               and CIDRs into NXDOMAIN.  Can be specified multiple times.
      --udp-buf-size=          Set the size of the UDP buffer in bytes. A value <= 0 will use the system default.
      --max-go-routines=       Set the maximum number of go routines. A value <= 0 will not not set a maximum.
      --pprof                  If present, exposes pprof information on localhost:6060.
      --version                Prints the program version

Help Options:
  -h, --help              Show this help message
```

## Examples

### Simple options

Runs a DNS proxy on `0.0.0.0:53` with a single upstream - Google DNS.
```shell
./dnsproxy -u 8.8.8.8:53
```

The same proxy with verbose logging enabled, writing the log to the file `log.txt`.
```shell
./dnsproxy -u 8.8.8.8:53 -v -o log.txt
```

Runs a DNS proxy on `127.0.0.1:5353` with multiple upstreams.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8:53 -u 1.1.1.1:53
```

Listen on multiple interfaces and ports:
```shell
./dnsproxy -l 127.0.0.1 -l 192.168.1.10 -p 5353 -p 5354 -u 1.1.1.1
```

The plain DNS upstream server may be specified in several ways:

 -  With a plain IP address:
    ```shell
    ./dnsproxy -l 127.0.0.1 -u 8.8.8.8:53
    ```

 -  With a hostname or plain IP address and the `udp://` scheme:
    ```shell
    ./dnsproxy -l 127.0.0.1 -u udp://dns.google -u udp://1.1.1.1
    ```

 -  With a hostname or plain IP address and the `tcp://` scheme to force using
    TCP:
    ```shell
    ./dnsproxy -l 127.0.0.1 -u tcp://dns.google -u tcp://1.1.1.1
    ```

### Encrypted upstreams

DNS-over-TLS upstream:
```shell
./dnsproxy -u tls://dns.adguard.com
```

DNS-over-HTTPS upstream with specified bootstrap DNS:
```shell
./dnsproxy -u https://dns.adguard.com/dns-query -b 1.1.1.1:53
```

DNS-over-QUIC upstream:
```shell
./dnsproxy -u quic://dns.adguard.com
```

DNS-over-HTTPS upstream with enabled HTTP/3 support (chooses it if it's faster):
```shell
./dnsproxy -u https://dns.google/dns-query --http3
```

DNS-over-HTTPS upstream with forced HTTP/3 (no fallback to other protocol):
```shell
./dnsproxy -u h3://dns.google/dns-query
```

DNSCrypt upstream ([DNS Stamp](https://dnscrypt.info/stamps) of AdGuard DNS):
```shell
./dnsproxy -u sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20
```

DNS-over-HTTPS upstream ([DNS Stamp](https://dnscrypt.info/stamps) of Cloudflare DNS):
```shell
./dnsproxy -u sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk
```

DNS-over-TLS upstream with two fallback servers (to be used when the main upstream is not available):
```shell
./dnsproxy -u tls://dns.adguard.com -f 8.8.8.8:53 -f 1.1.1.1:53
```

### Encrypted DNS server

Runs a DNS-over-TLS proxy on `127.0.0.1:853`.
```shell
./dnsproxy -l 127.0.0.1 --tls-port=853 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```

Runs a DNS-over-HTTPS proxy on `127.0.0.1:443`.
```shell
./dnsproxy -l 127.0.0.1 --https-port=443 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```

Runs a DNS-over-HTTPS proxy on `127.0.0.1:443` with HTTP/3 support.
```shell
./dnsproxy -l 127.0.0.1 --https-port=443 --http3 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```

Runs a DNS-over-QUIC proxy on `127.0.0.1:853`.
```shell
./dnsproxy -l 127.0.0.1 --quic-port=853 --tls-crt=example.crt --tls-key=example.key -u 8.8.8.8:53 -p 0
```

Runs a DNSCrypt proxy on `127.0.0.1:443`.

```shell
./dnsproxy -l 127.0.0.1 --dnscrypt-config=./dnscrypt-config.yaml --dnscrypt-port=443 --upstream=8.8.8.8:53 -p 0
```

> Please note that in order to run a DNSCrypt proxy, you need to obtain a DNSCrypt configuration first. You can use the https://github.com/ameshkov/dnscrypt command-line tool to do that with a command like this: `./dnscrypt generate --provider-name=2.dnscrypt-cert.example.org --out=dnscrypt-config.yaml`

### Additional features

Runs a DNS proxy on `0.0.0.0:53` with the rate limit set to `10 rps`, the DNS cache enabled, and type=ANY requests refused.
```shell
./dnsproxy -u 8.8.8.8:53 -r 10 --cache --refuse-any
```

Runs a DNS proxy on `127.0.0.1:5353` with multiple upstreams and enables parallel queries to all configured upstream servers.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8:53 -u 1.1.1.1:53 -u tls://dns.adguard.com --all-servers
```

Loads upstreams list from a file.
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u ./upstreams.txt
```

### DNS64 server

`dnsproxy` is capable of working as a DNS64 server.

> **What is DNS64/NAT64?**
>
> This is a mechanism for providing IPv6 access to IPv4. Using a NAT64 gateway
> with IPv4-IPv6 translation capability lets IPv6-only clients connect to
> IPv4-only services via synthetic IPv6 addresses starting with a prefix that
> routes them to the NAT64 gateway. DNS64 is a DNS service that returns AAAA
> records with these synthetic IPv6 addresses for IPv4-only destinations
> (with A but not AAAA records in the DNS). This lets IPv6-only clients use
> NAT64 gateways without any other configuration.

See also [RFC 6147](https://datatracker.ietf.org/doc/html/rfc6147).

Enables DNS64 with the default [Well-Known Prefix][wkp]:
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8 --private-rdns-upstream=127.0.0.1  --dns64
```

You can also specify any number of custom DNS64 prefixes:
```shell
./dnsproxy -l 127.0.0.1 -p 5353 -u 8.8.8.8 --private-rdns-upstream=127.0.0.1 --dns64 --dns64-prefix=64:ffff:: --dns64-prefix=32:ffff::
```

Note that only the first specified prefix will be used for synthesis.

PTR queries for addresses within the specified ranges or the
[Well-Known one][wkp] can only be answered with locally appropriate data, so
dnsproxy will route those to the local upstream servers.  Those should be
specified (e.g. with `--private-rdns-upstream`) if DNS64 is enabled.

[wkp]: https://datatracker.ietf.org/doc/html/rfc6052#section-2.1

### Fastest addr + cache-min-ttl

This option is useful for users with problematic network connections.
In this mode, `dnsproxy` detects the fastest IP address among all that were returned
and returns only that one.

Additionally, for those with problematic network connections, it makes sense to override `cache-min-ttl`.
In this case, `dnsproxy` will make sure that DNS responses are cached for at least the specified amount of time.

It makes sense to run it with multiple upstream servers only.

Runs a DNS proxy with two upstreams, min-TTL set to 10 minutes, and fastest address detection enabled:
```shell
./dnsproxy -u 8.8.8.8 -u 1.1.1.1 --cache --cache-min-ttl=600 --fastest-addr
```

This mode is intended for users who run `dnsproxy` with multiple upstreams.

### Specifying upstreams for domains

You can specify upstreams that will be used for a specific domain(s). We use the dnsmasq-like syntax (see `--server` description [here](http://www.thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html)).

**Syntax:** `[/[domain1][/../domainN]/]upstreamString`

If one or more domains are specified, that upstream (`upstreamString`) is used only for those domains. Usually, it is used for private nameservers. For instance, if you have a nameserver on your network which deals with `xxx.internal.local` at `192.168.0.1`, then you can specify `[/internal.local/]192.168.0.1`, and dnsproxy will send all queries for `internal.local` and its subdomains to that nameserver. Everything else will be sent to the default upstreams (which are mandatory!).

1. An empty domain specification, `//`, has the special meaning of "unqualified names only", i.e. names without any dots in them.
2. More specific domains take precedence over less specific domains, so: `--upstream=[/host.com/]1.2.3.4 --upstream=[/www.host.com/]2.3.4.5` will send queries for \*.host.com to 1.2.3.4, except \*.www.host.com, which will go to 2.3.4.5.
3. The special server address `#` means, "use the standard servers", so: `--upstream=[/host.com/]1.2.3.4 --upstream=[/www.host.com/]#` will send queries for \*.host.com to 1.2.3.4, except \*.www.host.com which will be forwarded as usual.
4. The wildcard `*` has special meaning of "any sub-domain", so: `--upstream=[/*.host.com/]1.2.3.4` will send queries for \*.host.com to 1.2.3.4, but host.com will be forwarded to default upstreams.

**Examples**

Sends queries for `*.local` domains to `192.168.0.1:53`. Other queries are sent to `8.8.8.8:53`.
```
./dnsproxy -u 8.8.8.8:53 -u [/local/]192.168.0.1:53
```

Sends queries for `*.host.com` to `1.1.1.1:53` except for `*.maps.host.com` which are sent to `8.8.8.8:53` (along with other queries).
```
./dnsproxy -u 8.8.8.8:53 -u [/host.com/]1.1.1.1:53 -u [/maps.host.com/]#
```

Sends queries for `*.host.com` to `1.1.1.1:53` except for `host.com` which is sent to `8.8.8.8:53` (along with other queries).
```
./dnsproxy -u 8.8.8.8:53 -u [/*.host.com/]1.1.1.1:53
```

### EDNS Client Subnet

To enable support for EDNS Client Subnet extension you should run dnsproxy with `--edns` flag:

```
./dnsproxy -u 8.8.8.8:53 --edns
```

Now if you connect to the proxy from the Internet, it will pass your original IP address's prefix through to the upstream server.  This way the upstream server may respond with IP addresses of the servers that are located near you to minimize latency.

If you want to use the EDNS Client Subnet feature when you're connecting to the proxy from a local network, you need to set the `--edns-addr=PUBLIC_IP` argument:

```
./dnsproxy -u 8.8.8.8:53 --edns --edns-addr=72.72.72.72
```

Now even if your IP address is 192.168.0.1 and it's not a public IP, the proxy will pass through 72.72.72.72 to the upstream server.

### Bogus NXDomain

This option is similar to dnsmasq `bogus-nxdomain`.  `dnsproxy` will transform
responses that contain at least a single IP address which is also specified by
the option into `NXDOMAIN`. Can be specified multiple times.

In the example below, we use AdGuard DNS server that returns `0.0.0.0` for
blocked domains, and transform them to `NXDOMAIN`.

```
./dnsproxy -u 94.140.14.14:53 --bogus-nxdomain=0.0.0.0
```

CIDR ranges are supported as well.  The following will respond with `NXDOMAIN`
instead of responses containing any IP from `192.168.0.0`-`192.168.255.255`:

```
./dnsproxy -u 192.168.0.15:53 --bogus-nxdomain=192.168.0.0/16
```
0707010000000D000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001D00000000dnsproxy-0.55.0/bamboo-specs0707010000000E000081A4000000000000000000000001650C592100000968000000000000000000000000000000000000002900000000dnsproxy-0.55.0/bamboo-specs/bamboo.yaml---
'version': 2
# The Bamboo plan configured by this spec.
'plan':
    'project-key': 'GO'
    'key': 'DNSPROXY'
    'name': 'dnsproxy - Build and run tests'
# Plan variables, referenced in the jobs below as ${bamboo.<name>}.
#
# NOTE(review): 'dockerFpm', 'maintainer', and 'name' aren't referenced anywhere
# in this spec; confirm whether they're still needed.
'variables':
    'dockerFpm': 'alanfranz/fpm-within-docker:ubuntu-bionic'
    # When there is a patch release of Go available, set this property to an
    # exact patch version as opposed to a minor one to make sure that this exact
    # version is actually used and not whatever the docker daemon on the CI has
    # cached a few months ago.
    'dockerGo': 'golang:1.20.8'
    'maintainer': 'Adguard Go Team'
    'name': 'dnsproxy'

# The plan's stages.  Each one lists the jobs it runs; the jobs themselves are
# defined as top-level keys below.
'stages':
  - 'Lint':
      'manual': false
      'final': false
      'jobs':
        - 'Lint'
  - 'Test':
      'manual': false
      'final': false
      'jobs':
        - 'Test'

# Lint job: installs the Go tools and runs the linters inside the Go Docker
# image.  The task properties are indented the same way as in the Test job.
'Lint':
    'docker':
        'image': '${bamboo.dockerGo}'
        'volumes':
            '${system.GO_CACHE_DIR}': '${bamboo.cacheGo}'
            '${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}'
    'key': 'LINT'
    'other':
        'clean-working-dir': true
    'requirements':
      - 'adg-docker': true
    'tasks':
      - 'checkout':
            'force-clean-build': true
      - 'script':
            'interpreter': 'SHELL'
            'scripts':
              - |
                set -e -f -u -x

                make VERBOSE=1 go-tools go-lint

# Test job: downloads the dependencies and runs the Go tests inside the same
# '${bamboo.dockerGo}' image as the Lint job.
'Test':
    'docker':
        'image': '${bamboo.dockerGo}'
        'volumes':
            '${system.GO_CACHE_DIR}': '${bamboo.cacheGo}'
            '${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}'
    'key': 'TEST'
    'other':
        'clean-working-dir': true
    'requirements':
      - 'adg-docker': true
    'tasks':
      - 'checkout':
            'force-clean-build': true
      - 'script':
            'interpreter': 'SHELL'
            # Projects that have go-bench and/or go-fuzz targets should add them
            # here as well.
            'scripts':
              - |
                set -e -f -u -x

                make VERBOSE=1 go-deps go-test

# Branch plan settings.  See the Bamboo Specs YAML documentation for the exact
# semantics of these keys.
'branches':
    'create': 'for-pull-request'
    'delete':
        'after-deleted-days': 1
        'after-inactive-days': 5
    'link-to-jira': true

# Posts plan status changes to the build webhook below.
'notifications':
  - 'events':
      - 'plan-status-changed'
    'recipients':
      - 'webhook':
            'name': 'Build webhook'
            'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo'

'labels': []

'other':
    'concurrent-build-plugin': 'system-default'
0707010000000F000081A4000000000000000000000001650C5921000001F4000000000000000000000000000000000000002100000000dnsproxy-0.55.0/config.yaml.dist# This is the yaml configuration file for dnsproxy with minimal working
# configuration.  All the available options can be seen with ./dnsproxy --help.
# To use it with dnsproxy, specify the --config-path=/<path-to-config.yaml>
# option.  Any other command-line options specified will override the values
# from the config file.
---
# Bootstrap DNS servers; presumably the same as the -b command-line option.
bootstrap:
  - "8.8.8.8:53"
# Addresses and ports to listen on; presumably the same as -l and -p.
listen-addrs:
  - "0.0.0.0"
listen-ports:
  - 53
# NOTE(review): the meaning of 0 for the following three options isn't stated
# here; check ./dnsproxy --help.
max-go-routines: 0
ratelimit: 0
udp-buf-size: 0
# Upstream DNS servers; presumably the same as -u.
upstream:
  - "1.1.1.1:53"
# NOTE(review): presumably the timeout for upstream queries; confirm with
# ./dnsproxy --help.
timeout: '10s'
07070100000010000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001700000000dnsproxy-0.55.0/docker07070100000011000081A4000000000000000000000001650C59210000072A000000000000000000000000000000000000002200000000dnsproxy-0.55.0/docker/Dockerfile# A docker file for scripts/make/build-docker.sh.

FROM alpine:3.18

ARG BUILD_DATE
ARG VERSION
ARG VCS_REF

LABEL\
	maintainer="AdGuard Team <devteam@adguard.com>" \
	org.opencontainers.image.authors="AdGuard Team <devteam@adguard.com>" \
	org.opencontainers.image.created=$BUILD_DATE \
	org.opencontainers.image.description="Simple DNS proxy with DoH, DoT, DoQ and DNSCrypt support" \
	org.opencontainers.image.documentation="https://github.com/AdguardTeam/dnsproxy" \
	org.opencontainers.image.licenses="Apache-2.0" \
	org.opencontainers.image.revision=$VCS_REF \
	org.opencontainers.image.source="https://github.com/AdguardTeam/dnsproxy" \
	org.opencontainers.image.title="dnsproxy" \
	org.opencontainers.image.url="https://github.com/AdguardTeam/dnsproxy" \
	org.opencontainers.image.vendor="AdGuard" \
	org.opencontainers.image.version=$VERSION

# Install CA certificates, libcap (presumably for the setcap call below), and
# time zone data, and create the application directory owned by nobody.
RUN apk --no-cache add ca-certificates libcap tzdata && \
	mkdir -p /opt/dnsproxy && chown -R nobody: /opt/dnsproxy

ARG DIST_DIR
ARG TARGETARCH
ARG TARGETOS
ARG TARGETVARIANT

COPY --chown=nobody:nogroup\
	./${DIST_DIR}/docker/dnsproxy_${TARGETOS}_${TARGETARCH}_${TARGETVARIANT}\
	/opt/dnsproxy/dnsproxy
COPY --chown=nobody:nogroup\
	./${DIST_DIR}/docker/config.yaml\
	/opt/dnsproxy/config.yaml

# Let the binary bind to ports below 1024 (53, 80, 443, 853) even when the
# container is run as a non-root user.
RUN setcap 'cap_net_bind_service=+eip' /opt/dnsproxy/dnsproxy

# 53     : TCP, UDP : DNS
# 80     : TCP      : HTTP
# 443    : TCP, UDP : HTTPS, DNS-over-HTTPS (incl. HTTP/3), DNSCrypt (main)
# 853    : TCP, UDP : DNS-over-TLS, DNS-over-QUIC
# 5443   : TCP, UDP : DNSCrypt (alt)
# 6060   : TCP      : HTTP (pprof)
EXPOSE 53/tcp 53/udp \
       80/tcp \
       443/tcp 443/udp \
       853/tcp 853/udp \
       5443/tcp 5443/udp \
       6060/tcp

WORKDIR /opt/dnsproxy

ENTRYPOINT ["/opt/dnsproxy/dnsproxy"]
CMD ["--config-path=/opt/dnsproxy/config.yaml"]
07070100000012000081A4000000000000000000000001650C592100000490000000000000000000000000000000000000002100000000dnsproxy-0.55.0/docker/README.md# DNS Proxy

A simple DNS proxy server that supports all existing DNS protocols including
`DNS-over-TLS`, `DNS-over-HTTPS`, `DNSCrypt`, and `DNS-over-QUIC`. Moreover,
it can work as a `DNS-over-HTTPS`, `DNS-over-TLS` or `DNS-over-QUIC` server.

Learn more about dnsproxy and its full capabilities in
its [GitHub repository][dnsproxy].

[dnsproxy]: https://github.com/AdguardTeam/dnsproxy

## Quick start

### Pull the Docker image

This command will pull the latest stable version:

```shell
docker pull adguard/dnsproxy
```

### Run the container

Run the container with the default configuration (see `config.yaml.dist` in the
repository) and expose DNS ports.

```shell
docker run --name dnsproxy \
  -p 53:53/tcp -p 53:53/udp \
  adguard/dnsproxy
```

Run the container with command-line args configuration and expose DNS ports.

```shell
docker run --name dnsproxy_google_dns \
  -p 53:53/tcp -p 53:53/udp \
  adguard/dnsproxy \
  -u 8.8.8.8:53
```

Run the container with a configuration file and expose DNS ports.

```shell
docker run --name dnsproxy_google_dns \
  -p 53:53/tcp -p 53:53/udp \
  -v "$PWD/config.yaml:/opt/dnsproxy/config.yaml" \
  adguard/dnsproxy
```
07070100000013000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001700000000dnsproxy-0.55.0/fastip07070100000014000081A4000000000000000000000001650C5921000009CB000000000000000000000000000000000000002000000000dnsproxy-0.55.0/fastip/cache.gopackage fastip

import (
	"encoding/binary"
	"net/netip"
	"time"
)

// TODO(e.burkov):  Rewrite the cache using zero-values instead of storing
// useless boolean as an integer.

// fastestAddrCacheTTLSec is the TTL of cached ping results, in seconds (ten
// minutes).
const fastestAddrCacheTTLSec = 600

// cacheEntry is a single cached ping result for an IP address.
type cacheEntry struct {
	// status is 0 if dialing the address succeeded and 1 if it failed.
	status int

	// latencyMsec is the measured dialing latency in milliseconds.
	latencyMsec uint
}

// packCacheEntry packs the cache entry and the TTL to bytes in the following
// order:
//
//	expire   [4]byte  (Unix time, seconds)
//	status   byte     (0 for ok, 1 for failure)
//	latency  [2]byte  (milliseconds)
//
// Latencies that don't fit into two bytes are clamped to the largest
// representable value instead of wrapping around, so that a very slow address
// can never be stored as a fast one.
func packCacheEntry(ent *cacheEntry, ttl uint32) []byte {
	// maxLatency is the largest latency representable in the packed form.
	const maxLatency = uint(^uint16(0))

	expire := uint32(time.Now().Unix()) + ttl

	latency := ent.latencyMsec
	if latency > maxLatency {
		latency = maxLatency
	}

	d := make([]byte, 4+1+2)
	binary.BigEndian.PutUint32(d, expire)
	d[4] = byte(ent.status)
	binary.BigEndian.PutUint16(d[5:], uint16(latency))

	return d
}

// unpackCacheEntry unpacks data in the format produced by packCacheEntry.  It
// returns nil if the record has expired or if data is too short to be a packed
// entry.
func unpackCacheEntry(data []byte) *cacheEntry {
	// Guard against truncated or foreign values instead of panicking on an
	// out-of-range slice access below.
	if len(data) < 4+1+2 {
		return nil
	}

	expire := binary.BigEndian.Uint32(data[:4])
	if int64(expire) <= time.Now().Unix() {
		return nil
	}

	return &cacheEntry{
		status:      int(data[4]),
		latencyMsec: uint(binary.BigEndian.Uint16(data[5:7])),
	}
}

// cacheFind looks up the cached ping result for ip.  It returns nil when there
// is no record for ip or when the record has expired.
func (f *FastestAddr) cacheFind(ip netip.Addr) *cacheEntry {
	if val := f.ipCache.Get(ip.AsSlice()); val != nil {
		return unpackCacheEntry(val)
	}

	return nil
}

// cacheAddFailure records a failed dial to ip.  An existing unexpired record
// for ip is left untouched.
func (f *FastestAddr) cacheAddFailure(ip netip.Addr) {
	f.ipCacheLock.Lock()
	defer f.ipCacheLock.Unlock()

	if f.cacheFind(ip) != nil {
		return
	}

	f.cacheAdd(&cacheEntry{status: 1}, ip, fastestAddrCacheTTLSec)
}

// cacheAddSuccessful records a successful dial to ip that took latency
// milliseconds.  The cached record is kept only if it is an unexpired success
// whose latency doesn't exceed latency; otherwise it's replaced.
func (f *FastestAddr) cacheAddSuccessful(ip netip.Addr, latency uint) {
	f.ipCacheLock.Lock()
	defer f.ipCacheLock.Unlock()

	prev := f.cacheFind(ip)
	keepPrev := prev != nil && prev.status == 0 && prev.latencyMsec <= latency
	if keepPrev {
		return
	}

	f.cacheAdd(&cacheEntry{latencyMsec: latency}, ip, fastestAddrCacheTTLSec)
}

// cacheAdd stores ent for ip in the cache; the record expires ttl seconds from
// now.
func (f *FastestAddr) cacheAdd(ent *cacheEntry, ip netip.Addr, ttl uint32) {
	f.ipCache.Set(ip.AsSlice(), packCacheEntry(ent, ttl))
}
07070100000015000081A4000000000000000000000001650C59210000083B000000000000000000000000000000000000002500000000dnsproxy-0.55.0/fastip/cache_test.gopackage fastip

import (
	"net"
	"net/netip"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestCacheAdd(t *testing.T) {
	f := NewFastestAddr()
	ip := netip.MustParseAddr("1.1.1.1")

	f.cacheAdd(&cacheEntry{status: 0, latencyMsec: 111}, ip, fastestAddrCacheTTLSec)

	// The freshly added record must be found.
	assert.NotNil(t, f.cacheFind(ip))
}

func TestCacheTtl(t *testing.T) {
	f := NewFastestAddr()
	ip := netip.MustParseAddr("1.1.1.1")

	// Store a record that lives for a single second and make sure it's there.
	f.cacheAdd(&cacheEntry{status: 0, latencyMsec: 111}, ip, 1)
	assert.NotNil(t, f.cacheFind(ip))

	// Once more than a second has passed, the record must be gone.
	time.Sleep(1001 * time.Millisecond)
	assert.Nil(t, f.cacheFind(ip))
}

func TestCacheAddSuccessfulOverwrite(t *testing.T) {
	f := NewFastestAddr()

	ip := netip.MustParseAddr("1.1.1.1")
	f.cacheAddFailure(ip)

	// Check that the failure has been recorded.  Stop on a nil entry instead of
	// dereferencing it below.
	ent := f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}
	assert.Equal(t, 1, ent.status)

	// Check that a success overwrites the existing failure record.
	f.cacheAddSuccessful(ip, 11)

	ent = f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}
	assert.Equal(t, 0, ent.status)
	assert.Equal(t, uint(11), ent.latencyMsec)
}

func TestCacheAddFailureNoOverwrite(t *testing.T) {
	f := NewFastestAddr()

	ip := netip.MustParseAddr("1.1.1.1")
	f.cacheAddSuccessful(ip, 11)

	// Check that the success has been recorded.  Stop on a nil entry instead of
	// dereferencing it below.
	ent := f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}
	assert.Equal(t, 0, ent.status)

	// Check that a failure doesn't overwrite the existing record.
	f.cacheAddFailure(ip)

	// Check that the old record is still there.
	ent = f.cacheFind(ip)
	if !assert.NotNil(t, ent) {
		return
	}
	assert.Equal(t, 0, ent.status)
	assert.Equal(t, uint(11), ent.latencyMsec)
}

// TestCache checks that entries stored both directly through the underlying
// cache and through cacheAdd can be found again with their contents intact.
func TestCache(t *testing.T) {
	f := NewFastestAddr()

	// Use the regular TTL rather than a single second so that the entry can't
	// expire between storing and looking it up.
	val := packCacheEntry(&cacheEntry{status: 0, latencyMsec: 111}, fastestAddrCacheTTLSec)
	f.ipCache.Set(net.ParseIP("1.1.1.1").To4(), val)

	f.cacheAdd(
		&cacheEntry{status: 0, latencyMsec: 222},
		netip.MustParseAddr("2.2.2.2"),
		fastestAddrCacheTTLSec,
	)

	ent := f.cacheFind(netip.MustParseAddr("1.1.1.1"))
	if assert.NotNil(t, ent) {
		assert.Equal(t, 0, ent.status)
		assert.Equal(t, uint(111), ent.latencyMsec)
	}

	ent = f.cacheFind(netip.MustParseAddr("2.2.2.2"))
	if assert.NotNil(t, ent) {
		assert.Equal(t, 0, ent.status)
		assert.Equal(t, uint(222), ent.latencyMsec)
	}
}
07070100000016000081A4000000000000000000000001650C5921000010B7000000000000000000000000000000000000002200000000dnsproxy-0.55.0/fastip/fastest.go// Package fastip implements the algorithm that allows to query multiple
// resolvers, ping all IP addresses that were returned, and return the fastest
// one among them.
package fastip

import (
	"net"
	"net/netip"
	"strings"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
	"golang.org/x/exp/maps"
)

// DefaultPingWaitTimeout is the default period of time for waiting ping
// operations to finish.
const DefaultPingWaitTimeout = 1 * time.Second

// FastestAddr provides methods to determine the fastest network addresses.
// Use NewFastestAddr to create one, since the zero value has no cache or lock.
type FastestAddr struct {
	// pinger is the dialer with predefined timeout for pinging TCP
	// connections.
	pinger *net.Dialer

	// ipCacheLock serializes the check-then-store updates of ipCache made by
	// cacheAddFailure and cacheAddSuccessful.  Plain lookups via cacheFind
	// don't take it.
	ipCacheLock *sync.Mutex

	// ipCache caches the results of pinging IP addresses: whether dialing an
	// address succeeded and, if so, its latency.  Keys are the bytes of the
	// address and values are encoded by packCacheEntry.
	ipCache cache.Cache

	// pingPorts are the TCP ports to ping on.
	pingPorts []uint

	// PingWaitTimeout is the timeout for waiting all the resolved addresses
	// are pinged.  Any ping results received after it are cached but not
	// used at the moment.  It should be configured right after the
	// FastestAddr initialization since it isn't protected for concurrent
	// usage.
	PingWaitTimeout time.Duration
}

// NewFastestAddr returns a new *FastestAddr that pings ports 80 and 443, uses
// the default timeouts, and keeps ping results in an LRU cache limited by a
// MaxSize of 64 * 1024.
func NewFastestAddr() (f *FastestAddr) {
	f = &FastestAddr{
		pinger:          &net.Dialer{Timeout: pingTCPTimeout},
		ipCacheLock:     &sync.Mutex{},
		pingPorts:       []uint{80, 443},
		PingWaitTimeout: DefaultPingWaitTimeout,
	}
	f.ipCache = cache.New(cache.Config{
		EnableLRU: true,
		MaxSize:   64 * 1024,
	})

	return f
}

// ExchangeFastest queries each specified upstream and returns a response with
// the fastest IP address.  The fastest IP address is considered to be the first
// one successfully dialed and other addresses are removed from the answer.
//
// If no fastest address can be determined, the first of the replies returned
// by upstream.ExchangeAll is returned as is, along with its upstream.
func (f *FastestAddr) ExchangeFastest(req *dns.Msg, ups []upstream.Upstream) (
	resp *dns.Msg,
	u upstream.Upstream,
	err error,
) {
	replies, err := upstream.ExchangeAll(ups, req)
	if err != nil {
		return nil, nil, err
	}

	// host is only used for logging, so a request without a question section
	// must not cause a panic here.
	var host string
	if len(req.Question) > 0 {
		host = strings.ToLower(req.Question[0].Name)
	}

	// Collect the unique addresses from the A and AAAA records of all replies.
	ipSet := map[netip.Addr]struct{}{}
	for _, r := range replies {
		for _, rr := range r.Resp.Answer {
			if ip := ipFromRR(rr); ip.IsValid() {
				ipSet[ip] = struct{}{}
			}
		}
	}

	if pingRes := f.pingAll(host, maps.Keys(ipSet)); pingRes != nil {
		return f.prepareReply(pingRes, replies)
	}

	log.Debug("%s: no fastest IP found, using the first response", host)

	return replies[0].Resp, replies[0].Upstream, nil
}

// prepareReply converts replies into the DNS answer message according to res.
// It picks the first reply whose answer section contains the address from res
// and keeps, among its A and AAAA records, only those carrying that address;
// all other records are kept.  The returned upstream is the one which replied
// with the fastest address.  If no reply contains the address, an error is
// logged and the first reply is returned unmodified.
func (f *FastestAddr) prepareReply(
	res *pingResult,
	replies []upstream.ExchangeAllResult,
) (resp *dns.Msg, u upstream.Upstream, err error) {
	ip := res.addrPort.Addr()

	idx := -1
	for i, r := range replies {
		if hasInAns(r.Resp, ip) {
			idx = i

			break
		}
	}

	if idx < 0 {
		log.Error("found no replies with IP %s, most likely this is a bug", ip)

		return replies[0].Resp, replies[0].Upstream, nil
	}

	resp, u = replies[idx].Resp, replies[idx].Upstream

	// Drop the A and AAAA records carrying any address other than the fastest
	// one.
	ipBytes := ip.AsSlice()
	filtered := make([]dns.RR, 0, len(resp.Answer))
	for _, rr := range resp.Answer {
		keep := true
		switch rec := rr.(type) {
		case *dns.A:
			keep = rec.A.Equal(ipBytes)
		case *dns.AAAA:
			keep = rec.AAAA.Equal(ipBytes)
		}

		if keep {
			filtered = append(filtered, rr)
		}
	}

	resp.Answer = filtered

	return resp, u, nil
}

// hasInAns reports whether any record in the answer section of m carries ip,
// as extracted by ipFromRR.
func hasInAns(m *dns.Msg, ip netip.Addr) (ok bool) {
	for i := range m.Answer {
		if ipFromRR(m.Answer[i]) == ip {
			return true
		}
	}

	return false
}

// ipFromRR extracts the address carried by an A or AAAA record.  For any other
// record type it returns the zero netip.Addr.  Conversion errors from
// netutil.IPToAddr are ignored.
func ipFromRR(rr dns.RR) (ip netip.Addr) {
	if a, isA := rr.(*dns.A); isA {
		ip, _ = netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
	} else if aaaa, isAAAA := rr.(*dns.AAAA); isAAAA {
		ip, _ = netutil.IPToAddr(aaaa.AAAA, netutil.AddrFamilyIPv6)
	}

	return ip
}
07070100000017000081A4000000000000000000000001650C592100000E78000000000000000000000000000000000000002700000000dnsproxy-0.55.0/fastip/fastest_test.gopackage fastip

import (
	"net"
	"net/netip"
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestFastestAddr_ExchangeFastest(t *testing.T) {
	t.Run("error", func(t *testing.T) {
		const errDesired errors.Error = "this is expected"

		u := &errUpstream{
			err: errDesired,
		}
		f := NewFastestAddr()

		resp, up, err := f.ExchangeFastest(testARequest(t), []upstream.Upstream{u})
		require.Error(t, err)

		assert.ErrorIs(t, err, errDesired)
		assert.Nil(t, resp)
		assert.Nil(t, up)
	})

	t.Run("one_dead", func(t *testing.T) {
		port := listen(t, netip.IPv4Unspecified())

		f := NewFastestAddr()
		f.pingPorts = []uint{port}

		// The alive IP is the just created local listener's address.  The dead
		// one is known as TEST-NET-1 which shouldn't be routed at all.  See
		// RFC-5737 (https://datatracker.ietf.org/doc/html/rfc5737).
		aliveIP, deadIP := net.IP{127, 0, 0, 1}, net.IP{192, 0, 2, 1}
		alive := new(testAUpstream).add(t.Name(), aliveIP)
		dead := new(testAUpstream).add(t.Name(), deadIP)

		rep, up, err := f.ExchangeFastest(testARequest(t), []upstream.Upstream{dead, alive})
		require.NoError(t, err)

		// testify expects the expected value first and the actual one second.
		assert.Equal(t, alive, up)

		require.NotNil(t, rep)
		require.NotEmpty(t, rep.Answer)
		require.IsType(t, new(dns.A), rep.Answer[0])

		ip := rep.Answer[0].(*dns.A).A
		assert.True(t, aliveIP.Equal(ip))
	})

	t.Run("all_dead", func(t *testing.T) {
		f := NewFastestAddr()
		f.pingPorts = []uint{getFreePort(t)}

		firstIP := net.IP{127, 0, 0, 1}
		up1 := new(testAUpstream).
			add(t.Name(), firstIP).
			add(t.Name(), net.IP{127, 0, 0, 2}).
			add(t.Name(), net.IP{127, 0, 0, 3})

		resp, _, err := f.ExchangeFastest(testARequest(t), []upstream.Upstream{up1})
		require.NoError(t, err)

		require.NotNil(t, resp)
		require.NotEmpty(t, resp.Answer)
		require.IsType(t, new(dns.A), resp.Answer[0])

		ip := resp.Answer[0].(*dns.A).A
		assert.True(t, firstIP.Equal(ip))
	})
}

type errUpstream struct {
	err      error
	closeErr error
}

// Address implements the [upstream.Upstream] interface for *errUpstream.  It
// always returns the same placeholder address.
func (u *errUpstream) Address() (addr string) {
	addr = "bad_upstream"

	return addr
}

// Exchange implements the [upstream.Upstream] interface for *errUpstream.  It
// ignores the request and always fails with u.err.
func (u *errUpstream) Exchange(_ *dns.Msg) (resp *dns.Msg, err error) {
	err = u.err

	return nil, err
}

// Close implements the [upstream.Upstream] interface for *errUpstream.  It
// returns u.closeErr.
func (u *errUpstream) Close() (err error) {
	err = u.closeErr

	return err
}

// testAUpstream is an upstream.Upstream that answers every request with the A
// records added via its add method.
type testAUpstream struct {
	// recs are the records put into the answer section of every response.
	recs []*dns.A
}

// type check
var _ upstream.Upstream = (*testAUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *testAUpstream.  It
// replies to m with all of u's A records in the answer section and never
// fails.
func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	resp = (&dns.Msg{}).SetReply(m)
	for i := range u.recs {
		resp.Answer = append(resp.Answer, u.recs[i])
	}

	return resp, nil
}

// Address implements the upstream.Upstream interface for *testAUpstream.  The
// address is always empty.
func (u *testAUpstream) Address() (addr string) {
	return addr
}

// Close implements the upstream.Upstream interface for *testAUpstream.  There
// is nothing to release, so it always succeeds.
func (u *testAUpstream) Close() (err error) {
	return err
}

// add appends an A record for host with address ip and a TTL of 60 seconds to
// the records u answers with.  It returns u so that calls can be chained.
func (u *testAUpstream) add(host string, ip net.IP) (chain *testAUpstream) {
	rec := &dns.A{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeA,
			Ttl:    60,
		},
		A: ip,
	}
	u.recs = append(u.recs, rec)

	return u
}

// testARequest returns a recursive A query in the IN class for the FQDN made
// from the current test's name.
func testARequest(t *testing.T) (req *dns.Msg) {
	q := dns.Question{
		Name:   dns.Fqdn(t.Name()),
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	req = &dns.Msg{Question: []dns.Question{q}}
	req.Id = dns.Id()
	req.RecursionDesired = true

	return req
}
07070100000018000081A4000000000000000000000001650C592100000D25000000000000000000000000000000000000001F00000000dnsproxy-0.55.0/fastip/ping.gopackage fastip

import (
	"net/netip"
	"time"

	"github.com/AdguardTeam/golibs/log"
)

// pingTCPTimeout is a TCP connection timeout.  It's higher than pingWaitTimeout
// since the slower connections will be cached anyway.
const pingTCPTimeout = 4 * time.Second

// pingResult is the result of dialing the address.
type pingResult struct {
	// addrPort is the address-port pair the result is related to.  The port
	// is 0 when the result wasn't produced by dialing, i.e. when pingAll takes
	// it from the cache or gets only a single address.
	addrPort netip.AddrPort
	// latency is the duration of dialing process in milliseconds.  It's 0
	// when pingAll gets only a single address and doesn't dial at all.
	latency uint
	// success is true when the dialing succeeded.
	success bool
}

// pingAll pings all ips concurrently and returns as soon as the fastest one is
// found or the timeout is exceeded.
//
// A single address is returned right away without pinging.  Otherwise, cached
// results are used where available: the fastest successfully cached address
// becomes the initial candidate, and addresses with a cached failure are
// skipped.  Every uncached address is dialed on each of f.pingPorts.  The
// first successful dial received within f.PingWaitTimeout is returned, unless
// the cached candidate is faster.  pr is nil if no address is known to be
// reachable.
func (f *FastestAddr) pingAll(host string, ips []netip.Addr) (pr *pingResult) {
	ipN := len(ips)
	switch ipN {
	case 0:
		return nil
	case 1:
		return &pingResult{
			addrPort: netip.AddrPortFrom(ips[0], 0),
			success:  true,
		}
	}

	portN := len(f.pingPorts)
	// Every pingDoTCP call sends exactly one result, and the channel has room
	// for all of them, so no goroutine blocks after pingAll has returned.
	resCh := make(chan *pingResult, ipN*portN)
	scheduled := 0

	// Find the fastest cached IP address and start pinging others.
	for _, ip := range ips {
		cached := f.cacheFind(ip)
		if cached == nil {
			for _, port := range f.pingPorts {
				go f.pingDoTCP(host, netip.AddrPortFrom(ip, uint16(port)), resCh)
			}
			scheduled += portN

			continue
		}

		if cached.status != 0 {
			// The address is cached as unreachable.
			continue
		}

		if pr == nil || cached.latencyMsec < pr.latency {
			pr = &pingResult{
				addrPort: netip.AddrPortFrom(ip, 0),
				latency:  cached.latencyMsec,
				success:  true,
			}
		}
	}

	cached := pr != nil
	if scheduled == 0 {
		if cached {
			log.Debug("pingAll: %s: return cached response: %s", host, pr.addrPort)
		} else {
			log.Debug("pingAll: %s: returning nothing", host)
		}

		return pr
	}

	// Use an explicit timer rather than time.After so that it's released as
	// soon as pingAll returns instead of staying alive until it fires.
	timer := time.NewTimer(f.PingWaitTimeout)
	defer timer.Stop()

	// Wait for the first successful ping result or the timeout.
	for i := 0; i < scheduled; i++ {
		select {
		case res := <-resCh:
			log.Debug(
				"pingAll: %s: got result for %s status %v",
				host,
				res.addrPort,
				res.success,
			)
			if !res.success {
				continue
			}

			if !cached || pr.latency >= res.latency {
				pr = res
			}

			return pr
		case <-timer.C:
			if cached {
				log.Debug(
					"pingAll: %s: pinging timed out, returning cached: %s",
					host,
					pr.addrPort,
				)
			} else {
				log.Debug(
					"pingAll: %s: ping checks timed out, returning nothing",
					host,
				)
			}

			return pr
		}
	}

	return pr
}

// pingDoTCP dials addrPort over TCP using f.pinger, sends the result of the
// dial into resCh, and then records it in the cache.
//
// The result is cached under the unmapped address, so an IPv4-mapped IPv6
// address is cached in its IPv4 form.  NOTE(review): pingAll looks addresses
// up without unmapping them; confirm that this asymmetry is intended.
func (f *FastestAddr) pingDoTCP(host string, addrPort netip.AddrPort, resCh chan *pingResult) {
	log.Debug("pingDoTCP: %s: connecting to %s", host, addrPort)

	start := time.Now()
	conn, err := f.pinger.Dial("tcp", addrPort.String())
	elapsed := time.Since(start)

	success := err == nil
	if success {
		if cerr := conn.Close(); cerr != nil {
			log.Debug("closing tcp connection: %s", cerr)
		}
	}

	latency := uint(elapsed.Milliseconds())

	resCh <- &pingResult{
		addrPort: addrPort,
		latency:  latency,
		success:  success,
	}

	// elapsed is a time.Duration, which already includes its unit when
	// formatted, so the messages below don't append "ms" to it.
	addr := addrPort.Addr().Unmap()
	if success {
		log.Debug("pingDoTCP: %s: elapsed %s on %s", host, elapsed, addrPort)
		f.cacheAddSuccessful(addr, latency)
	} else {
		log.Debug(
			"pingDoTCP: %s: failed to connect to %s, elapsed %s: %v",
			host,
			addrPort,
			elapsed,
			err,
		)
		f.cacheAddFailure(addr)
	}
}
07070100000019000081A4000000000000000000000001650C5921000014B5000000000000000000000000000000000000002400000000dnsproxy-0.55.0/fastip/ping_test.gopackage fastip

import (
	"net"
	"net/netip"
	"runtime"
	"sync"
	"syscall"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// unit is the convenient alias for struct{}.  The tests below use it as the
// element type of channels that are only used for signaling.
type unit = struct{}

func TestFastestAddr_PingAll_timeout(t *testing.T) {
	t.Run("isolated", func(t *testing.T) {
		f := NewFastestAddr()

		// waitCh holds back every dial until it's closed.  Closing it in the
		// cleanup releases all of the blocked pinging goroutines, even those
		// beyond the first one and even if the test fails early, so that none
		// of them is leaked.
		waitCh := make(chan unit)
		t.Cleanup(func() { close(waitCh) })

		f.pinger.Control = func(_, _ string, _ syscall.RawConn) error {
			<-waitCh

			return nil
		}

		ip := netutil.IPv4Localhost()
		res := f.pingAll("", []netip.Addr{ip, ip})
		require.Nil(t, res)
	})

	t.Run("cached", func(t *testing.T) {
		f := NewFastestAddr()

		const lat uint = 42

		ip1 := netutil.IPv4Localhost()
		ip2 := netip.MustParseAddr("127.0.0.2")
		f.cacheAddSuccessful(ip1, lat)

		// See the comment in the "isolated" subtest.
		waitCh := make(chan unit)
		t.Cleanup(func() { close(waitCh) })

		f.pinger.Control = func(_, _ string, _ syscall.RawConn) error {
			<-waitCh

			return nil
		}

		res := f.pingAll("", []netip.Addr{ip1, ip2})
		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, lat, res.latency)
	})
}

// assertCaching waits, for at most pingTCPTimeout, until the cache of f has
// an entry for ip with the given status, and fails t otherwise.
func assertCaching(t *testing.T, f *FastestAddr, ip netip.Addr, status int) {
	t.Helper()

	const tickDur = pingTCPTimeout / 16

	hasStatus := func() (ok bool) {
		if ce := f.cacheFind(ip); ce != nil {
			return ce.status == status
		}

		return false
	}

	assert.Eventually(t, hasStatus, pingTCPTimeout, tickDur)
}

func TestFastestAddr_PingAll_cache(t *testing.T) {
	ip := netutil.IPv4Localhost()

	t.Run("cached_failed", func(t *testing.T) {
		f := NewFastestAddr()
		f.cacheAddFailure(ip)

		res := f.pingAll("", []netip.Addr{ip, ip})
		require.Nil(t, res)
	})

	t.Run("cached_successful", func(t *testing.T) {
		const lat uint = 1

		f := NewFastestAddr()
		f.cacheAddSuccessful(ip, lat)

		res := f.pingAll("", []netip.Addr{ip, ip})
		require.NotNil(t, res)
		assert.True(t, res.success)
		assert.Equal(t, lat, res.latency)
	})

	t.Run("not_cached", func(t *testing.T) {
		listener, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, listener.Close)

		f := NewFastestAddr()

		f.pingPorts = []uint{uint(listener.Addr().(*net.TCPAddr).Port)}
		ips := []netip.Addr{ip, ip}

		wg := &sync.WaitGroup{}
		wg.Add(len(ips) * len(f.pingPorts))

		// Control is called from the pinging goroutines, not from the one
		// running the test, and t.FailNow must not be called from those.  Use
		// a panicking T for the requirements there instead.
		pt := testutil.PanicT{}
		f.pinger.Control = func(_, address string, _ syscall.RawConn) error {
			hostport, perr := netutil.ParseHostPort(address)
			require.NoError(pt, perr)

			assert.Equal(t, ip.String(), hostport.Host)
			assert.Contains(t, f.pingPorts, uint(hostport.Port))

			wg.Done()

			return nil
		}

		res := f.pingAll("", ips)
		require.NotNil(t, res)

		assert.True(t, res.success)
		assertCaching(t, f, ip, 0)

		wg.Wait()
	})
}

// listen starts a TCP listener on a random port of ip, registers its closing
// in the cleanup of t, and returns the port it listens on.
func listen(t *testing.T, ip netip.Addr) (port uint) {
	t.Helper()

	bindAddr := netip.AddrPortFrom(ip, 0)
	ln, err := net.Listen("tcp", bindAddr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, ln.Close)

	tcpAddr := ln.Addr().(*net.TCPAddr)

	return uint(tcpAddr.Port)
}

func TestFastestAddr_PingAll(t *testing.T) {
	ip := netutil.IPv4Localhost()

	t.Run("single", func(t *testing.T) {
		f := NewFastestAddr()
		res := f.pingAll("", []netip.Addr{ip})
		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, ip, res.addrPort.Addr())
		// There was no ping so the port is zero.
		assert.Zero(t, res.addrPort.Port())

		// Nothing in the cache since there was no ping.
		ce := f.cacheFind(res.addrPort.Addr())
		require.Nil(t, ce)
	})

	t.Run("fastest", func(t *testing.T) {
		fastPort := listen(t, ip)
		slowPort := listen(t, ip)

		// ctrlCh holds back the dials to slowPort until it's closed.
		ctrlCh := make(chan unit)

		f := NewFastestAddr()
		f.pingPorts = []uint{
			fastPort,
			slowPort,
		}

		// Control is called from the pinging goroutines, where t.FailNow must
		// not be called, so use a panicking T for the requirements there.
		pt := testutil.PanicT{}
		f.pinger.Control = func(_, address string, _ syscall.RawConn) error {
			addrPort := netip.MustParseAddrPort(address)
			require.Contains(pt, []uint{fastPort, slowPort}, uint(addrPort.Port()))
			if addrPort.Port() == uint16(fastPort) {
				return nil
			}

			<-ctrlCh

			return nil
		}

		ips := []netip.Addr{ip, ip}
		res := f.pingAll("", ips)

		// Close instead of sending a single value, since more than one dial
		// to slowPort may be blocked, and each of them must be released.
		close(ctrlCh)

		require.NotNil(t, res)

		assert.True(t, res.success)
		assert.Equal(t, ip, res.addrPort.Addr())
		assert.EqualValues(t, fastPort, res.addrPort.Port())

		assertCaching(t, f, ip, 0)
	})

	t.Run("zero", func(t *testing.T) {
		res := NewFastestAddr().pingAll("", nil)
		require.Nil(t, res)
	})

	t.Run("fail", func(t *testing.T) {
		port := getFreePort(t)

		f := NewFastestAddr()
		f.pingPorts = []uint{port}

		res := f.pingAll("test", []netip.Addr{ip, ip})
		require.Nil(t, res)

		assertCaching(t, f, ip, 1)
	})
}

// getFreePort returns the number of a TCP port on 127.0.0.1 that nothing
// listens on at the moment of return.  It finds one by briefly listening on a
// random port and closing the listener.
//
// TODO(e.burkov):  The logic is underwhelming.  Find a more accurate way.
func getFreePort(t *testing.T) (port uint) {
	t.Helper()

	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	tcpAddr := l.Addr().(*net.TCPAddr)

	// Release the port right away so that nothing listens on it.
	require.NoError(t, l.Close())

	// Windows may need a moment before the port is actually released.
	if runtime.GOOS == "windows" {
		time.Sleep(100 * time.Millisecond)
	}

	return uint(tcpAddr.Port)
}
0707010000001A000081A4000000000000000000000001650C5921000005F5000000000000000000000000000000000000001700000000dnsproxy-0.55.0/go.modmodule github.com/AdguardTeam/dnsproxy

go 1.20

require (
	github.com/AdguardTeam/golibs v0.16.2
	github.com/ameshkov/dnscrypt/v2 v2.2.7
	github.com/ameshkov/dnsstamps v1.0.3
	github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0
	github.com/bluele/gcache v0.0.2
	github.com/jessevdk/go-flags v1.5.0
	github.com/miekg/dns v1.1.56
	github.com/patrickmn/go-cache v2.1.0+incompatible
	github.com/quic-go/quic-go v0.38.1
	github.com/stretchr/testify v1.8.4
	golang.org/x/exp v0.0.0-20230905200255-921286631fa9
	golang.org/x/net v0.15.0
	golang.org/x/sys v0.12.0
	gopkg.in/yaml.v3 v3.0.1
)

require (
	github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
	github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
	github.com/golang/mock v1.6.0 // indirect
	github.com/google/pprof v0.0.0-20230912144702-c363fe2c2ed8 // indirect
	github.com/kr/text v0.2.0 // indirect
	github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
	github.com/onsi/ginkgo/v2 v2.12.1 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	github.com/quic-go/qpack v0.4.0 // indirect
	github.com/quic-go/qtls-go1-20 v0.3.4 // indirect
	golang.org/x/crypto v0.13.0 // indirect
	golang.org/x/mod v0.12.0 // indirect
	golang.org/x/text v0.13.0 // indirect
	golang.org/x/tools v0.13.0 // indirect
	gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
)
0707010000001B000081A4000000000000000000000001650C59210000244D000000000000000000000000000000000000001700000000dnsproxy-0.55.0/go.sumgithub.com/AdguardTeam/golibs v0.16.2 h1:54286tqaGZl3L13EV1PbaMnGqJkFJdaVtqFpDNEKZi8=
github.com/AdguardTeam/golibs v0.16.2/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us=
github.com/ameshkov/dnscrypt/v2 v2.2.7 h1:aEitLIR8HcxVodZ79mgRcCiC0A0I5kZPBuWGFwwulAw=
github.com/ameshkov/dnscrypt/v2 v2.2.7/go.mod h1:qPWhwz6FdSmuK7W4sMyvogrez4MWdtzosdqlr0Rg3ow=
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
github.com/ameshkov/dnsstamps v1.0.3/go.mod h1:Ii3eUu73dx4Vw5O4wjzmT5+lkCwovjzaEZZ4gKyIH5A=
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 h1:0b2vaepXIfMsG++IsjHiI2p4bxALD1Y2nQKGMR5zDQM=
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA=
github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw=
github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20230912144702-c363fe2c2ed8 h1:gpptm606MZYGaMHMsB4Srmb6EbW/IVHnt04rcMXnkBQ=
github.com/google/pprof v0.0.0-20230912144702-c363fe2c2ed8/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE=
github.com/miekg/dns v1.1.56/go.mod h1:cRm6Oo2C8TY9ZS/TqsSrseAcncm74lfK5G+ikN2SWWY=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/onsi/ginkgo/v2 v2.12.1 h1:uHNEO1RP2SpuZApSkel9nEh1/Mu+hmQe7Q+Pepg5OYA=
github.com/onsi/ginkgo/v2 v2.12.1/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/quic-go/quic-go v0.38.1 h1:M36YWA5dEhEeT+slOu/SwMEucbYd0YFidxG3KlGPZaE=
github.com/quic-go/quic-go v0.38.1/go.mod h1:ijnZM7JsFIkp4cRyjxJNIzdSfCLmUMg9wdyhGmg+SN4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
0707010000001C000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001900000000dnsproxy-0.55.0/internal0707010000001D000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000002300000000dnsproxy-0.55.0/internal/bootstrap0707010000001E000081A4000000000000000000000001650C592100000CAE000000000000000000000000000000000000003000000000dnsproxy-0.55.0/internal/bootstrap/bootstrap.go// Package bootstrap provides types and functions to resolve upstream hostnames
// and to dial retrieved addresses.
package bootstrap

import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"net/url"
	"time"

	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
)

// DialHandler is a dial function for creating unencrypted network connections
// to the upstream server.  It establishes the connection to the server
// specified at initialization and ignores the addr.
type DialHandler func(ctx context.Context, network, addr string) (conn net.Conn, err error)

// ResolveDialContext resolves the hostname of u with resolvers and returns a
// DialHandler that dials the resolved addresses on the port of u.  The
// addresses are sorted according to preferIPv6, and invalid ones are dropped.
// A positive timeout limits both the resolution and each dial.  u must not be
// nil.
func ResolveDialContext(
	u *url.URL,
	timeout time.Duration,
	resolvers []Resolver,
	preferIPv6 bool,
) (h DialHandler, err error) {
	defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }()

	host, port, err := netutil.SplitHostPort(u.Host)
	if err != nil {
		// The error is informative enough by itself, and the deferred
		// annotation above already adds the context.
		return nil, err
	}

	lookupCtx := context.Background()
	if timeout > 0 {
		var cancel context.CancelFunc
		lookupCtx, cancel = context.WithTimeout(lookupCtx, timeout)
		defer cancel()
	}

	ips, err := LookupParallel(lookupCtx, resolvers, host)
	if err != nil {
		return nil, fmt.Errorf("resolving hostname: %w", err)
	}

	proxynetutil.SortNetIPAddrs(ips, preferIPv6)

	// Sorting moves invalid addresses to the end, so stop at the first one.
	addrs := make([]string, 0, len(ips))
	for i := 0; i < len(ips) && ips[i].IsValid(); i++ {
		addrs = append(addrs, netip.AddrPortFrom(ips[i], uint16(port)).String())
	}

	return NewDialContext(timeout, addrs...), nil
}

// NewDialContext returns a DialHandler that dials addrs one by one, in order,
// and returns the first connection that succeeds.  Each dial is limited by
// timeout when it's positive.  When addrs is empty, the returned handler
// always fails.
//
// TODO(e.burkov):  Consider using [Resolver] instead of
// [upstream.Options.Bootstrap] and [upstream.Options.ServerIPAddrs].
func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) {
	total := len(addrs)
	if total == 0 {
		log.Debug("bootstrap: no addresses to dial")

		return func(_ context.Context, _, _ string) (conn net.Conn, err error) {
			return nil, errors.Error("no addresses")
		}
	}

	d := &net.Dialer{Timeout: timeout}

	// TODO(e.burkov):  Check IPv6 preference here.

	return func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
		// The address argument is deliberately ignored: only the addresses
		// captured at construction time are dialed.
		errs := make([]error, 0, total)
		for n, addr := range addrs {
			log.Debug("bootstrap: dialing %s (%d/%d)", addr, n+1, total)

			began := time.Now()
			c, dialErr := d.DialContext(ctx, network, addr)
			took := time.Since(began)
			if dialErr != nil {
				log.Debug("bootstrap: connection to %s failed in %s: %s", addr, took, dialErr)
				errs = append(errs, dialErr)

				continue
			}

			log.Debug("bootstrap: connection to %s succeeded in %s", addr, took)

			return c, nil
		}

		// TODO(e.burkov):  Use errors.Join in Go 1.20.
		return nil, errors.List("all dialers failed", errs...)
	}
}
0707010000001F000081A4000000000000000000000001650C592100001086000000000000000000000000000000000000003500000000dnsproxy-0.55.0/internal/bootstrap/bootstrap_test.gopackage bootstrap_test

import (
	"context"
	"net"
	"net/netip"
	"net/url"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testTimeout is a common timeout used in tests of this package.
const testTimeout = 1 * time.Second

// newListener starts a listener of the given network type on a random port of
// 127.0.0.1 and returns its address; the listener is closed in the cleanup of
// t.  For every accepted connection, its local address is sent into sig and
// the connection is then closed, so the caller must read sig properly.
func newListener(t testing.TB, network string, sig chan net.Addr) (ipp netip.AddrPort) {
	t.Helper()

	// TODO(e.burkov):  Listen IPv6 as well, when the CI adds IPv6 interfaces.
	l, err := net.Listen(network, "127.0.0.1:0")
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, l.Close)

	go func() {
		pt := testutil.PanicT{}
		for {
			c, lerr := l.Accept()
			if errors.Is(lerr, net.ErrClosed) {
				return
			}

			require.NoError(pt, lerr)

			testutil.RequireSend(pt, sig, c.LocalAddr(), testTimeout)

			require.NoError(pt, c.Close())
		}
	}()

	ipp, err = netip.ParseAddrPort(l.Addr().String())
	require.NoError(t, err)

	return ipp
}

// See the details here: https://github.com/AdguardTeam/dnsproxy/issues/18
func TestResolveDialContext(t *testing.T) {
	sig := make(chan net.Addr, 1)

	ipp := newListener(t, "tcp", sig)
	port := ipp.Port()

	testCases := []struct {
		name       string
		addresses  []netip.Addr
		preferIPv6 bool
	}{{
		name:       "v4",
		addresses:  []netip.Addr{netutil.IPv4Localhost()},
		preferIPv6: false,
	}, {
		name:       "both_prefer_v6",
		addresses:  []netip.Addr{netutil.IPv4Localhost(), netutil.IPv6Localhost()},
		preferIPv6: true,
	}, {
		name:       "both_prefer_v4",
		addresses:  []netip.Addr{netutil.IPv6Localhost(), netutil.IPv4Localhost()},
		preferIPv6: false,
	}, {
		name:       "strip_invalid",
		addresses:  []netip.Addr{{}, netutil.IPv4Localhost(), {}, netutil.IPv6Localhost(), {}},
		preferIPv6: true,
	}}

	const hostname = "host.name"

	pt := testutil.PanicT{}

	for _, tc := range testCases {
		r := &testResolver{
			onLookupNetIP: func(
				_ context.Context,
				network string,
				host string,
			) (addrs []netip.Addr, err error) {
				require.Equal(pt, "ip", network)
				require.Equal(pt, hostname, host)

				return tc.addresses, nil
			},
		}

		t.Run(tc.name, func(t *testing.T) {
			dialContext, err := bootstrap.ResolveDialContext(
				&url.URL{Host: netutil.JoinHostPort(hostname, port)},
				testTimeout,
				[]bootstrap.Resolver{r},
				tc.preferIPv6,
			)
			require.NoError(t, err)

			conn, err := dialContext(context.Background(), "tcp", "")
			require.NoError(t, err)

			// Close the client side of the connection as well, so that each
			// case doesn't leak a socket.
			testutil.CleanupAndRequireSuccess(t, conn.Close)

			expected, ok := testutil.RequireReceive(t, sig, testTimeout)
			require.True(t, ok)

			assert.Equal(t, expected.String(), conn.RemoteAddr().String())
		})
	}

	t.Run("no_addresses", func(t *testing.T) {
		r := &testResolver{
			onLookupNetIP: func(
				_ context.Context,
				network string,
				host string,
			) (addrs []netip.Addr, err error) {
				require.Equal(pt, "ip", network)
				require.Equal(pt, hostname, host)

				return nil, nil
			},
		}

		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: netutil.JoinHostPort(hostname, port)},
			testTimeout,
			[]bootstrap.Resolver{r},
			false,
		)
		require.NoError(t, err)

		_, err = dialContext(context.Background(), "tcp", "")
		testutil.AssertErrorMsg(t, "no addresses", err)
	})

	t.Run("bad_hostname", func(t *testing.T) {
		const errMsg = `dialing "bad hostname": address bad hostname: ` +
			`missing port in address`

		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: "bad hostname"},
			testTimeout,
			nil,
			false,
		)
		testutil.AssertErrorMsg(t, errMsg, err)

		assert.Nil(t, dialContext)
	})

	t.Run("no_resolvers", func(t *testing.T) {
		dialContext, err := bootstrap.ResolveDialContext(
			&url.URL{Host: netutil.JoinHostPort(hostname, port)},
			testTimeout,
			nil,
			false,
		)
		assert.ErrorIs(t, err, bootstrap.ErrNoResolvers)
		assert.Nil(t, dialContext)
	})
}
07070100000020000081A4000000000000000000000001650C5921000009F6000000000000000000000000000000000000003400000000dnsproxy-0.55.0/internal/bootstrap/hostsresolver.gopackage bootstrap

import (
	"context"
	"fmt"
	"io/fs"
	"net/netip"

	"github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/hostsfile"
	"github.com/AdguardTeam/golibs/log"
	"golang.org/x/exp/slices"
)

// HostsResolver is a [Resolver] that looks up IP addresses in the
// hostname-to-addresses mappings taken from a [netutil.Hosts] when it's
// created.
type HostsResolver struct {
	// addrs is an actual source of IP addresses, keyed by hostname.  Lookups
	// only read it and return copies of its slices.
	addrs map[string][]netip.Addr
}

// NewHostsResolver returns a resolver that answers lookups using the
// hostname-to-addresses mappings of hosts.
func NewHostsResolver(hosts *netutil.Hosts) (hr *HostsResolver) {
	_, byName := hosts.Mappings()

	return &HostsResolver{
		addrs: byName,
	}
}

// NewDefaultHostsResolver returns a resolver based on system hosts files
// provided by the [hostsfile.DefaultHostsPaths] and read from rootFSys.  Files
// that don't exist are skipped; any other error is returned.
//
// TODO(e.burkov):  Use.
func NewDefaultHostsResolver(rootFSys fs.FS) (hr *HostsResolver, err error) {
	paths, err := hostsfile.DefaultHostsPaths()
	if err != nil {
		return nil, fmt.Errorf("getting default hosts paths: %w", err)
	}

	// Check the error instead of discarding it, since the returned hosts
	// can't be relied upon when it's not nil.
	hosts, err := netutil.NewHosts()
	if err != nil {
		return nil, fmt.Errorf("creating hosts: %w", err)
	}

	for _, name := range paths {
		err = parseHostsFile(rootFSys, hosts, name)
		if err != nil {
			// Don't wrap the error since it's already informative enough as is.
			return nil, err
		}
	}

	return NewHostsResolver(hosts), nil
}

// parseHostsFile opens the file called name in fsys and parses its contents
// into hosts.  A missing file is only logged and isn't treated as an error.
func parseHostsFile(fsys fs.FS, hosts *netutil.Hosts, name string) (err error) {
	f, err := fsys.Open(name)
	switch {
	case err == nil:
		// Go on.
	case errors.Is(err, fs.ErrNotExist):
		log.Debug("hosts file %q doesn't exist", name)

		return nil
	default:
		// The error is informative enough as is, so don't wrap it.
		return err
	}

	// TODO(e.burkov):  Use [errors.Join] when it will be supported by all
	// dependencies.
	defer func() { err = errors.WithDeferred(err, f.Close()) }()

	return hostsfile.Parse(hosts, f, nil)
}

// type check
var _ Resolver = (*HostsResolver)(nil)

// LookupNetIP implements the [Resolver] interface for *HostsResolver.  It
// returns a copy of the addresses known for host.  For the "ip4" and "ip6"
// networks, the addresses of the other family are filtered out.  Any other
// network besides "ip" is an error.  The context isn't used, since the lookup
// is a map access.
func (hr *HostsResolver) LookupNetIP(
	_ context.Context,
	network string,
	host string,
) (addrs []netip.Addr, err error) {
	var isOtherFamily func(netip.Addr) (ok bool)
	switch network {
	case "ip4":
		isOtherFamily = netip.Addr.Is6
	case "ip6":
		isOtherFamily = netip.Addr.Is4
	case "ip":
		return slices.Clone(hr.addrs[host]), nil
	default:
		return nil, fmt.Errorf("unsupported network %q", network)
	}

	// Clone before deleting, since DeleteFunc modifies the slice in place and
	// the stored addresses must stay intact.
	return slices.DeleteFunc(slices.Clone(hr.addrs[host]), isOtherFamily), nil
}
07070100000021000081A4000000000000000000000001650C592100000898000000000000000000000000000000000000003900000000dnsproxy-0.55.0/internal/bootstrap/hostsresolver_test.gopackage bootstrap_test

import (
	"context"
	"net/netip"
	"strings"
	"testing"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestHostsResolver_LookupNetIP(t *testing.T) {
	const hostsData = `
1.2.3.4 host1 host2 ipv4.only
::1 host1 host2 ipv6.only
`

	var (
		v4Addr = netip.MustParseAddr("1.2.3.4")
		v6Addr = netip.MustParseAddr("::1")
	)

	hosts, err := netutil.NewHosts(strings.NewReader(hostsData))
	require.NoError(t, err)

	hr := bootstrap.NewHostsResolver(hosts)

	testCases := []struct {
		name      string
		host      string
		net       string
		wantAddrs []netip.Addr
	}{{
		name:      "canonical_any",
		host:      "host1",
		net:       "ip",
		wantAddrs: []netip.Addr{v4Addr, v6Addr},
	}, {
		name:      "canonical_v4",
		host:      "host1",
		net:       "ip4",
		wantAddrs: []netip.Addr{v4Addr},
	}, {
		name:      "canonical_v6",
		host:      "host1",
		net:       "ip6",
		wantAddrs: []netip.Addr{v6Addr},
	}, {
		name:      "alias_any",
		host:      "host2",
		net:       "ip",
		wantAddrs: []netip.Addr{v4Addr, v6Addr},
	}, {
		name:      "alias_v4",
		host:      "host2",
		net:       "ip4",
		wantAddrs: []netip.Addr{v4Addr},
	}, {
		name:      "alias_v6",
		host:      "host2",
		net:       "ip6",
		wantAddrs: []netip.Addr{v6Addr},
	}, {
		name:      "unknown_host",
		host:      "host3",
		net:       "ip",
		wantAddrs: nil,
	}, {
		name:      "unknown_host_v4",
		host:      "host3",
		net:       "ip4",
		wantAddrs: nil,
	}, {
		name:      "family_mismatch_v4",
		host:      "ipv6.only",
		net:       "ip4",
		wantAddrs: []netip.Addr{},
	}, {
		name:      "family_mismatch_v6",
		host:      "ipv4.only",
		net:       "ip6",
		wantAddrs: []netip.Addr{},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Use a local error to avoid sharing a variable between subtests.
			addrs, lookupErr := hr.LookupNetIP(context.Background(), tc.net, tc.host)
			require.NoError(t, lookupErr)

			assert.Equal(t, tc.wantAddrs, addrs)
		})
	}

	t.Run("returns_copy", func(t *testing.T) {
		addrs, lookupErr := hr.LookupNetIP(context.Background(), "ip", "host1")
		require.NoError(t, lookupErr)
		require.NotEmpty(t, addrs)

		// Modifying the result must not affect the resolver's data.
		addrs[0] = netip.Addr{}

		addrs, lookupErr = hr.LookupNetIP(context.Background(), "ip", "host1")
		require.NoError(t, lookupErr)

		assert.Equal(t, []netip.Addr{v4Addr, v6Addr}, addrs)
	})

	t.Run("unsupported_network", func(t *testing.T) {
		_, lookupErr := hr.LookupNetIP(context.Background(), "ip5", "host1")
		testutil.AssertErrorMsg(t, `unsupported network "ip5"`, lookupErr)
	})
}
07070100000022000081A4000000000000000000000001650C592100000989000000000000000000000000000000000000002F00000000dnsproxy-0.55.0/internal/bootstrap/resolver.gopackage bootstrap

import (
	"context"
	"net"
	"net/netip"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
)

// Resolver resolves the hostnames to IP addresses.  *net.Resolver satisfies
// it, and [LookupParallel] queries several of them concurrently.
type Resolver interface {
	// LookupNetIP looks up the IP addresses for the given host.  network must
	// be one of "ip", "ip4" or "ip6".
	LookupNetIP(ctx context.Context, network, host string) (addrs []netip.Addr, err error)
}

// type check
var _ Resolver = (*net.Resolver)(nil)

// ErrNoResolvers is returned by [LookupParallel] when it's given an empty list
// of resolvers.
const ErrNoResolvers errors.Error = "no resolvers specified"

// LookupParallel looks up the IP addresses of host with all resolvers
// concurrently and returns the first successful result.  If every resolver
// fails, the returned error lists all of their errors.  ErrNoResolvers is
// returned when resolvers is empty.
func LookupParallel(
	ctx context.Context,
	resolvers []Resolver,
	host string,
) (addrs []netip.Addr, err error) {
	n := len(resolvers)
	if n == 0 {
		return nil, ErrNoResolvers
	}

	if n == 1 {
		// There is nothing to race against, so look up synchronously.
		return lookup(ctx, resolvers[0], host)
	}

	// Buffer the channel for every resolver, so that the goroutines still
	// running after the first success never block on sending.
	results := make(chan *lookupResult, n)
	for _, r := range resolvers {
		go lookupAsync(ctx, r, host, results)
	}

	errs := make([]error, 0, n)
	for i := 0; i < n; i++ {
		res := <-results
		if res.err != nil {
			errs = append(errs, res.err)

			continue
		}

		return res.addrs, nil
	}

	// TODO(e.burkov):  Use [errors.Join] in Go 1.20.
	return nil, errors.List("all resolvers failed", errs...)
}

// lookupResult is a structure that represents the result of a lookup by a
// single resolver, as sent from lookupAsync to LookupParallel.
type lookupResult struct {
	// err is the lookup error; addrs is only used when it's nil.
	err   error
	// addrs are the addresses returned by the resolver.
	addrs []netip.Addr
}

// lookupAsync looks up the IP addresses of host with r and sends the result
// into resCh.  It's intended to be used as a goroutine.  Exactly one result is
// always sent, even if the lookup panics, because LookupParallel receives one
// result per resolver and would otherwise wait forever.
func lookupAsync(ctx context.Context, r Resolver, host string, resCh chan<- *lookupResult) {
	// Preset the error, so that it's reported when lookup panics and the
	// assignment below never happens.
	res := &lookupResult{
		err: errors.Error("parallel lookup: resolver panicked"),
	}

	// Deferred calls run in reverse order, so the result is sent after
	// log.OnPanic has handled the panic, if there was one.
	defer func() { resCh <- res }()
	defer log.OnPanic("parallel lookup")

	res.addrs, res.err = lookup(ctx, r, host)
}

// lookup resolves host into IP addresses of any family using r and logs the
// outcome together with the time it took.
func lookup(ctx context.Context, r Resolver, host string) (addrs []netip.Addr, err error) {
	began := time.Now()
	addrs, err = r.LookupNetIP(ctx, "ip", host)
	took := time.Since(began)

	switch {
	case err != nil:
		log.Debug("parallel lookup: lookup for %s failed in %s: %s", host, took, err)
	default:
		log.Debug("parallel lookup: lookup for %s succeeded in %s: %s", host, took, addrs)
	}

	return addrs, err
}
07070100000023000081A4000000000000000000000001650C592100000A83000000000000000000000000000000000000003400000000dnsproxy-0.55.0/internal/bootstrap/resolver_test.gopackage bootstrap_test

import (
	"context"
	"fmt"
	"net/netip"
	"testing"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testResolver is the [Resolver] interface implementation for testing purposes.
type testResolver struct {
	onLookupNetIP func(ctx context.Context, network, host string) (addrs []netip.Addr, err error)
}

// LookupNetIP implements the [Resolver] interface for *testResolver.
func (r *testResolver) LookupNetIP(
	ctx context.Context,
	network string,
	host string,
) (addrs []netip.Addr, err error) {
	return r.onLookupNetIP(ctx, network, host)
}

// TestLookupParallel checks [bootstrap.LookupParallel] with no resolvers, with
// a single resolver, with a fast and a blocked resolver, and with resolvers
// that all fail.
//
// NOTE(review): testTimeout is not declared in this file; presumably it is
// declared in another test file of this package.
func TestLookupParallel(t *testing.T) {
	const hostname = "host.name"

	t.Run("no_resolvers", func(t *testing.T) {
		addrs, err := bootstrap.LookupParallel(context.Background(), nil, "")
		assert.ErrorIs(t, err, bootstrap.ErrNoResolvers)
		assert.Nil(t, addrs)
	})

	// pt is used for assertions inside the resolver callbacks.  These may run
	// in goroutines other than the one running the test, where t.FailNow must
	// not be called.
	pt := testutil.PanicT{}
	hostAddrs := []netip.Addr{netutil.IPv4Localhost()}

	// immediate validates its arguments and answers with hostAddrs right away.
	immediate := &testResolver{
		onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
			require.Equal(pt, hostname, host)
			require.Equal(pt, "ip", network)

			return hostAddrs, nil
		},
	}

	t.Run("one_resolver", func(t *testing.T) {
		addrs, err := bootstrap.LookupParallel(
			context.Background(),
			[]bootstrap.Resolver{immediate},
			hostname,
		)
		require.NoError(t, err)

		assert.Equal(t, hostAddrs, addrs)
	})

	t.Run("two_resolvers", func(t *testing.T) {
		// delayCh keeps the delayed resolver blocked until LookupParallel has
		// returned, so the result can only come from immediate.
		delayCh := make(chan struct{}, 1)
		delayed := &testResolver{
			onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
				require.Equal(pt, hostname, host)
				require.Equal(pt, "ip", network)

				testutil.RequireReceive(pt, delayCh, testTimeout)

				return []netip.Addr{netutil.IPv6Localhost()}, nil
			},
		}

		addrs, err := bootstrap.LookupParallel(
			context.Background(),
			[]bootstrap.Resolver{immediate, delayed},
			hostname,
		)
		require.NoError(t, err)
		// Unblock the delayed resolver's goroutine so that it can finish.
		testutil.RequireSend(t, delayCh, struct{}{}, testTimeout)

		assert.Equal(t, hostAddrs, addrs)
	})

	t.Run("all_errors", func(t *testing.T) {
		err := assert.AnError
		// The message lists the error of each of the three failing resolvers.
		wantErrMsg := fmt.Sprintf("all resolvers failed: 3 errors: %[1]q, %[1]q, %[1]q", err)

		r := &testResolver{
			onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
				return nil, assert.AnError
			},
		}

		addrs, err := bootstrap.LookupParallel(
			context.Background(),
			[]bootstrap.Resolver{r, r, r},
			hostname,
		)
		testutil.AssertErrorMsg(t, wantErrMsg, err)
		assert.Nil(t, addrs)
	})
}
07070100000024000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000002100000000dnsproxy-0.55.0/internal/netutil07070100000025000081A4000000000000000000000001650C592100000F2E000000000000000000000000000000000000002A00000000dnsproxy-0.55.0/internal/netutil/hosts.gopackage netutil

import (
	"fmt"
	"io"
	"net/netip"
	"strings"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/hostsfile"
	"github.com/AdguardTeam/golibs/log"
	"golang.org/x/exp/slices"
)

// unit is a convenient alias for empty struct.
type unit = struct{}

// set is a helper type that removes duplicates.
type set[K string | netip.Addr] map[K]unit

// orderedSet is a helper type for storing values in original adding order and
// dealing with duplicates.
type orderedSet[K string | netip.Addr] struct {
	set  set[K]
	vals []K
}

// add adds val to os if it's not already there.
func (os *orderedSet[K]) add(key, val K) {
	if _, ok := os.set[key]; !ok {
		os.set[key] = unit{}
		os.vals = append(os.vals, val)
	}
}

// Convenience aliases for [orderedSet].
type (
	namesSet = orderedSet[string]
	addrsSet = orderedSet[netip.Addr]
)

// Hosts is a [hostsfile.HandleSet] that removes duplicates.
//
// It must be initialized with [NewHosts].
//
// TODO(e.burkov):  Think of storing only slices.
type Hosts struct {
	// names maps each address to its names in original case and in original
	// adding order without duplicates.
	names map[netip.Addr]*namesSet

	// addrs maps each host to its addresses in original adding order without
	// duplicates.
	addrs map[string]*addrsSet
}

// Make sure that *Hosts implements [hostsfile.HandleSet] at compile time.
var (
	_ hostsfile.HandleSet = (*Hosts)(nil)
)

// NewHosts returns a new Hosts set filled with the records parsed from each of
// readers, in order.  readers are optional.  The error is only returned when
// parsing fails, and it mentions the index of the reader that caused it.
func NewHosts(readers ...io.Reader) (h *Hosts, err error) {
	h = &Hosts{
		addrs: map[string]*addrsSet{},
		names: map[netip.Addr]*namesSet{},
	}

	for idx := range readers {
		err = hostsfile.Parse(h, readers[idx], nil)
		if err != nil {
			return nil, fmt.Errorf("reader at index %d: %w", idx, err)
		}
	}

	return h, nil
}

// type check
//
// This used to repeat the [hostsfile.HandleSet] check declared right after the
// Hosts type.  It now asserts the [hostsfile.Set] interface that the Add
// method documents itself as implementing.
var _ hostsfile.Set = (*Hosts)(nil)

// Add implements the [hostsfile.Set] interface for *Hosts.  Each name of rec is
// remembered for rec.Addr in its original case.  rec.Addr is remembered for the
// lowercased version of each name.  Duplicates are skipped.
func (h *Hosts) Add(rec *hostsfile.Record) {
	addr := rec.Addr

	nameSet := h.names[addr]
	if nameSet == nil {
		nameSet = &namesSet{set: set[string]{}}
		h.names[addr] = nameSet
	}

	for _, name := range rec.Names {
		key := strings.ToLower(name)
		nameSet.add(key, name)

		addrSet := h.addrs[key]
		if addrSet == nil {
			addrSet = &addrsSet{
				set:  set[netip.Addr]{},
				vals: []netip.Addr{},
			}
			h.addrs[key] = addrSet
		}

		addrSet.add(addr, addr)
	}
}

// HandleInvalid implements the [hostsfile.HandleSet] interface for *Hosts.  It
// only logs the invalid line at the debug level.  Empty lines and comments are
// skipped silently.  Errors other than line errors are logged as unexpected.
func (h *Hosts) HandleInvalid(srcName string, _ []byte, err error) {
	var lineErr *hostsfile.LineError

	switch {
	case !errors.As(err, &lineErr):
		log.Debug("hostset: unexpected error from hostsfile: %s", err)
	case errors.Is(err, hostsfile.ErrEmptyLine):
		// Ignore empty lines and comments.
	default:
		log.Debug("hostset: source %q: %s", srcName, lineErr)
	}
}

// ByAddr returns the host names of addr in their original case and adding
// order, without duplicates, or nil if h doesn't know addr.  The returned slice
// is h's internal one and must not be modified.
func (h *Hosts) ByAddr(addr netip.Addr) (hosts []string) {
	nameSet, found := h.names[addr]
	if !found {
		return nil
	}

	return nameSet.vals
}

// ByName returns the addresses of host, matched case-insensitively, in their
// adding order and without duplicates, or nil if h doesn't know host.  The
// returned slice is h's internal one and must not be modified.
func (h *Hosts) ByName(host string) (addrs []netip.Addr) {
	addrSet, found := h.addrs[strings.ToLower(host)]
	if !found {
		return nil
	}

	return addrSet.vals
}

// Mappings returns copies of both internal mappings: from each address to its
// names, and from each lowercased name to its addresses.  The slices in the
// returned maps don't share memory with h.
func (h *Hosts) Mappings() (names map[netip.Addr][]string, addrs map[string][]netip.Addr) {
	names = make(map[netip.Addr][]string, len(h.names))
	for addr, nameSet := range h.names {
		names[addr] = slices.Clone(nameSet.vals)
	}

	addrs = make(map[string][]netip.Addr, len(h.addrs))
	for name, addrSet := range h.addrs {
		addrs[name] = slices.Clone(addrSet.vals)
	}

	return names, addrs
}
07070100000026000081A4000000000000000000000001650C592100000B7F000000000000000000000000000000000000002F00000000dnsproxy-0.55.0/internal/netutil/hosts_test.gopackage netutil_test

import (
	"io/fs"
	"net/netip"
	"os"
	"path"
	"testing"

	"github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/exp/maps"
	"golang.org/x/exp/slices"
)

// testdata is the file system rooted at the package's testdata directory,
// which holds the hosts files used by the tests.
var testdata fs.FS = os.DirFS("./testdata")

// TestHosts checks that [netutil.Hosts] built from testdata files removes
// duplicate names case-insensitively and duplicate addresses, while keeping
// the original case and adding order.  It also checks that invalid lines in a
// hosts file don't make parsing fail.
func TestHosts(t *testing.T) {
	t.Parallel()

	// h and err are shared with the subtests below; h is filled by good_file.
	var h *netutil.Hosts
	var err error
	t.Run("good_file", func(t *testing.T) {
		var f fs.File
		f, err = testdata.Open(path.Join(t.Name(), "hosts"))
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, f.Close)

		h, err = netutil.NewHosts(f)
	})
	require.NoError(t, err)

	// Variables mirroring the testdata/TestHosts/hosts file.
	var (
		v4Addr1 = netip.MustParseAddr("0.0.0.1")
		v4Addr2 = netip.MustParseAddr("0.0.0.2")

		mappedAddr1 = netip.MustParseAddr("::ffff:0.0.0.1")
		mappedAddr2 = netip.MustParseAddr("::ffff:0.0.0.2")

		v6Addr1 = netip.MustParseAddr("::1")
		v6Addr2 = netip.MustParseAddr("::2")

		// wantHosts is keyed by lowercased names.
		wantHosts = map[string][]netip.Addr{
			"host.one":       {v4Addr1, mappedAddr1, v6Addr1},
			"host.two":       {v4Addr2, mappedAddr2, v6Addr2},
			"host.new":       {v4Addr2, v4Addr1, mappedAddr2, mappedAddr1, v6Addr2, v6Addr1},
			"again.host.two": {v4Addr2, mappedAddr2, v6Addr2},
		}

		// wantAddrs keeps the case of the first occurrence of each name.
		wantAddrs = map[netip.Addr][]string{
			v4Addr1:     {"Host.One", "host.new"},
			v4Addr2:     {"Host.Two", "Host.New", "Again.Host.Two"},
			mappedAddr1: {"Host.One", "host.new"},
			mappedAddr2: {"Host.Two", "Host.New", "Again.Host.Two"},
			v6Addr1:     {"Host.One", "host.new"},
			v6Addr2:     {"Host.Two", "Host.New", "Again.Host.Two"},
		}
	)

	t.Run("Mappings", func(t *testing.T) {
		names, addrs := h.Mappings()
		assert.Equal(t, wantAddrs, names)
		assert.Equal(t, wantHosts, addrs)
	})

	// ByAddr and ByHost are parallel subtests, so they resume only after this
	// function returns.  They only read h, not the shared err.
	t.Run("ByAddr", func(t *testing.T) {
		t.Parallel()

		// Sort keys to make the test deterministic.
		addrs := maps.Keys(wantAddrs)
		slices.SortFunc(addrs, netip.Addr.Compare)

		for _, addr := range addrs {
			addr := addr
			t.Run(addr.String(), func(t *testing.T) {
				t.Parallel()

				assert.Equal(t, wantAddrs[addr], h.ByAddr(addr))
			})
		}
	})

	t.Run("ByHost", func(t *testing.T) {
		t.Parallel()

		// Sort keys to make the test deterministic.
		hosts := maps.Keys(wantHosts)
		slices.Sort(hosts)

		for _, host := range hosts {
			host := host
			t.Run(host, func(t *testing.T) {
				t.Parallel()

				assert.Equal(t, wantHosts[host], h.ByName(host))
			})
		}
	})

	// Invalid lines must be handled by HandleInvalid, not reported by NewHosts.
	t.Run("bad_file", func(t *testing.T) {
		var f fs.File
		f, err = testdata.Open(path.Join(t.Name(), "hosts"))
		require.NoError(t, err)
		testutil.CleanupAndRequireSuccess(t, f.Close)

		_, err = netutil.NewHosts(f)
		require.NoError(t, err)
	})

	t.Run("non-line_error", func(t *testing.T) {
		assert.NotPanics(t, func() {
			(&netutil.Hosts{}).HandleInvalid("test", nil, assert.AnError)
		})
	})
}
07070100000027000081A4000000000000000000000001650C5921000001FA000000000000000000000000000000000000003100000000dnsproxy-0.55.0/internal/netutil/listenconfig.gopackage netutil

import "net"

// ListenConfig returns a new [net.ListenConfig] whose Control function is the
// platform-specific defaultListenControl.  The plain-DNS servers in this module
// use it by default.
//
// TODO(a.garipov): Add tests.
//
// TODO(a.garipov): Add an option to not set SO_REUSEPORT on Unix to prevent
// issues with OpenWrt.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/5872.
//
// TODO(a.garipov): DRY with AdGuard DNS when we can.
func ListenConfig() (lc *net.ListenConfig) {
	lc = &net.ListenConfig{}
	lc.Control = defaultListenControl

	return lc
}
07070100000028000081A4000000000000000000000001650C59210000045E000000000000000000000000000000000000003600000000dnsproxy-0.55.0/internal/netutil/listenconfig_unix.go//go:build unix

package netutil

import (
	"fmt"
	"syscall"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"golang.org/x/sys/unix"
)

// defaultListenControl is a [net.ListenConfig.Control] function that enables
// SO_REUSEADDR and SO_REUSEPORT on the socket.  A missing SO_REUSEPORT
// (ENOPROTOOPT) is only logged, since some Linux systems lack it.  The error of
// setting the options takes precedence over the error of c.Control itself.
func defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
	var sockOptErr error
	err = c.Control(func(fd uintptr) {
		sock := int(fd)

		sockOptErr = unix.SetsockoptInt(sock, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
		if sockOptErr != nil {
			sockOptErr = fmt.Errorf("setting SO_REUSEADDR: %w", sockOptErr)

			return
		}

		sockOptErr = unix.SetsockoptInt(sock, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
		switch {
		case sockOptErr == nil:
			// Both options are set.
		case errors.Is(sockOptErr, unix.ENOPROTOOPT):
			// Some Linux OSs do not seem to support SO_REUSEPORT, including
			// some varieties of OpenWrt.  Issue a warning.
			log.Info("warning: SO_REUSEPORT not supported: %s", sockOptErr)
			sockOptErr = nil
		default:
			sockOptErr = fmt.Errorf("setting SO_REUSEPORT: %w", sockOptErr)
		}
	})

	return errors.WithDeferred(sockOptErr, err)
}
07070100000029000081A4000000000000000000000001650C5921000000D6000000000000000000000000000000000000003900000000dnsproxy-0.55.0/internal/netutil/listenconfig_windows.go//go:build windows

package netutil

import "syscall"

// defaultListenControl is left nil on Windows, which doesn't support
// SO_REUSEPORT, so the listen config gets no Control function there.
var defaultListenControl func(network, address string, c syscall.RawConn) error
0707010000002A000081A4000000000000000000000001650C59210000071A000000000000000000000000000000000000002C00000000dnsproxy-0.55.0/internal/netutil/netutil.go// Package netutil contains network-related utilities common among dnsproxy
// packages.
//
// TODO(a.garipov): Move improved versions of these into netutil in module
// golibs.
package netutil

import (
	"net"
	"net/netip"

	glnetutil "github.com/AdguardTeam/golibs/netutil"
	"golang.org/x/exp/slices"
)

// SortIPAddrs sorts addrs in accordance with the protocol preferences.  Invalid
// addresses are sorted near the end, keeping their relative order.  Zones are
// ignored.  IPv4-mapped IPv6 addresses are treated as IPv4 ones.
//
// TODO(a.garipov): Use netip.Addr instead of net.IPAddr everywhere where this
// is called.
func SortIPAddrs(addrs []net.IPAddr, preferIPv6 bool) {
	if len(addrs) <= 1 {
		return
	}

	slices.SortStableFunc(addrs, func(addrA, addrB net.IPAddr) (res int) {
		// Assume that len(addrs) is mostly small, so these conversions aren't
		// as expensive as they could have been.
		a, errA := glnetutil.IPToAddrNoMapped(addrA.IP)
		b, errB := glnetutil.IPToAddrNoMapped(addrB.IP)

		// The comparison must be antisymmetric.  Two invalid addresses compare
		// equal, so that the stable sort doesn't reorder them.
		switch aBad, bBad := errA != nil, errB != nil; {
		case aBad && bBad:
			return 0
		case aBad:
			return 1
		case bBad:
			return -1
		}

		aIs4, bIs4 := a.Is4(), b.Is4()
		if aIs4 == bIs4 {
			return a.Compare(b)
		}

		if aIs4 {
			if preferIPv6 {
				return 1
			}

			return -1
		}

		if preferIPv6 {
			return -1
		}

		return 1
	})
}

// SortNetIPAddrs sorts addrs in accordance with the protocol preferences.
// Invalid addresses are sorted near the end.  Zones are ignored.
func SortNetIPAddrs(addrs []netip.Addr, preferIPv6 bool) {
	l := len(addrs)
	if l <= 1 {
		return
	}

	slices.SortStableFunc(addrs, func(addrA, addrB netip.Addr) (res int) {
		if !addrA.IsValid() {
			return 1
		} else if !addrB.IsValid() {
			return -1
		}

		aIs4, bIs4 := addrA.Is4(), addrB.Is4()
		if aIs4 == bIs4 {
			return addrA.Compare(addrB)
		}

		if aIs4 {
			if preferIPv6 {
				return 1
			}

			return -1
		}

		if preferIPv6 {
			return -1
		}

		return 1
	})
}
0707010000002B000081A4000000000000000000000001650C59210000035D000000000000000000000000000000000000003900000000dnsproxy-0.55.0/internal/netutil/netutil_example_test.gopackage netutil_test

import (
	"fmt"
	"net"

	"github.com/AdguardTeam/dnsproxy/internal/netutil"
)

// ExampleSortIPAddrs shows how SortIPAddrs orders the same addresses when IPv4
// is preferred and when IPv6 is preferred.  The nil address cannot be
// converted, so it ends up last in both cases.
func ExampleSortIPAddrs() {
	// printAddrs prints header followed by the 1-based numbered addresses and
	// an empty line.
	printAddrs := func(header string, addrs []net.IPAddr) {
		fmt.Printf("%s:\n", header)
		for i, a := range addrs {
			fmt.Printf("%d: %s\n", i+1, a.IP)
		}

		fmt.Println()
	}

	addrs := []net.IPAddr{{
		IP: net.ParseIP("1.2.3.4"),
	}, {
		IP: net.ParseIP("1.2.3.5"),
	}, {
		IP: net.ParseIP("2a00::1234"),
	}, {
		IP: net.ParseIP("2a00::1235"),
	}, {
		IP: nil,
	}}
	netutil.SortIPAddrs(addrs, false)
	printAddrs("IPv4 preferred", addrs)

	netutil.SortIPAddrs(addrs, true)
	printAddrs("IPv6 preferred", addrs)

	// Output:
	//
	// IPv4 preferred:
	// 1: 1.2.3.4
	// 2: 1.2.3.5
	// 3: 2a00::1234
	// 4: 2a00::1235
	// 5: <nil>
	//
	// IPv6 preferred:
	// 1: 2a00::1234
	// 2: 2a00::1235
	// 3: 1.2.3.4
	// 4: 1.2.3.5
	// 5: <nil>
}
0707010000002C000081A4000000000000000000000001650C59210000081D000000000000000000000000000000000000003100000000dnsproxy-0.55.0/internal/netutil/netutil_test.gopackage netutil_test

import (
	"net/netip"
	"testing"

	"github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/stretchr/testify/assert"
	"golang.org/x/exp/slices"
)

// TestSortNetIPAddrs checks, on a table of inputs, that SortNetIPAddrs puts the
// preferred protocol first and invalid addresses last.  The input slices are
// cloned so that the table itself is never modified.
func TestSortNetIPAddrs(t *testing.T) {
	var (
		aIPv4    = netip.MustParseAddr("1.2.3.4")
		bIPv4    = netip.MustParseAddr("4.3.2.1")
		aIPv6    = netip.MustParseAddr("2a00::1234")
		bIPv6    = netip.MustParseAddr("2a00::4321")
		// badIP is the invalid zero value, since parsing "bad" fails.
		badIP, _ = netip.ParseAddr("bad")
	)

	testCases := []struct {
		name       string
		addrs      []netip.Addr
		want       []netip.Addr
		preferIPv6 bool
	}{{
		name:       "v4_preferred",
		addrs:      []netip.Addr{aIPv6, bIPv6, badIP, aIPv4, bIPv4},
		want:       []netip.Addr{aIPv4, bIPv4, aIPv6, bIPv6, badIP},
		preferIPv6: false,
	}, {
		name:       "v6_preferred",
		addrs:      []netip.Addr{aIPv4, bIPv4, badIP, aIPv6, bIPv6},
		want:       []netip.Addr{aIPv6, bIPv6, aIPv4, bIPv4, badIP},
		preferIPv6: true,
	}, {
		name:       "shuffled_v4_preferred",
		addrs:      []netip.Addr{badIP, aIPv4, bIPv6, aIPv6, bIPv4},
		want:       []netip.Addr{aIPv4, bIPv4, aIPv6, bIPv6, badIP},
		preferIPv6: false,
	}, {
		name:       "shuffled_v6_preferred",
		addrs:      []netip.Addr{badIP, aIPv4, bIPv6, aIPv6, bIPv4},
		want:       []netip.Addr{aIPv6, bIPv6, aIPv4, bIPv4, badIP},
		preferIPv6: true,
	}, {
		name:       "empty",
		addrs:      []netip.Addr{},
		want:       []netip.Addr{},
		preferIPv6: false,
	}, {
		name:       "single",
		addrs:      []netip.Addr{aIPv4},
		want:       []netip.Addr{aIPv4},
		preferIPv6: false,
	}, {
		name:       "start_with_ipv4",
		addrs:      []netip.Addr{aIPv4, aIPv6, bIPv4, bIPv6},
		want:       []netip.Addr{aIPv6, bIPv6, aIPv4, bIPv4},
		preferIPv6: true,
	}, {
		name:       "start_with_ipv6",
		addrs:      []netip.Addr{aIPv6, aIPv4, bIPv6, bIPv4},
		want:       []netip.Addr{aIPv6, bIPv6, aIPv4, bIPv4},
		preferIPv6: true,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Sort a copy, since SortNetIPAddrs sorts in place.
			ips := slices.Clone(tc.addrs)
			netutil.SortNetIPAddrs(ips, tc.preferIPv6)
			assert.Equal(t, tc.want, ips)
		})
	}
}
0707010000002D000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000002A00000000dnsproxy-0.55.0/internal/netutil/testdata0707010000002E000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000003400000000dnsproxy-0.55.0/internal/netutil/testdata/TestHosts0707010000002F000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000003D00000000dnsproxy-0.55.0/internal/netutil/testdata/TestHosts/bad_file07070100000030000081A4000000000000000000000001650C5921000000B6000000000000000000000000000000000000004300000000dnsproxy-0.55.0/internal/netutil/testdata/TestHosts/bad_file/hosts# comment about the following empty line

# comment about the above empty line

1.2.3.256 a.b # invalid address
1.2.3.4 a.123 # invalid top-level domain
1.2.3.4 .a.b  # empty domain
07070100000031000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000003E00000000dnsproxy-0.55.0/internal/netutil/testdata/TestHosts/good_file07070100000032000081A4000000000000000000000001650C59210000033A000000000000000000000000000000000000004400000000dnsproxy-0.55.0/internal/netutil/testdata/TestHosts/good_file/hosts# IPv4

# 1st host.
0.0.0.1 Host.One

# 2nd host.
0.0.0.2 Host.Two

# 1st host full duplicate.
0.0.0.1 host.one

# 2nd host duplicate with new name.
0.0.0.2 host.two Host.New

# 1st host with foreign name.
0.0.0.1 host.new

# 2nd host new name.
0.0.0.2 Again.Host.Two

# Mapped

# 1st host.
::ffff:0.0.0.1 Host.One

# 2nd host.
::ffff:0.0.0.2 Host.Two

# 1st host full duplicate.
::ffff:0.0.0.1 host.one

# 2nd host duplicate with new name.
::ffff:0.0.0.2 host.two Host.New

# 1st host with foreign name.
::ffff:0.0.0.1 host.new

# 2nd host new name.
::ffff:0.0.0.2 Again.Host.Two

# IPv6

# 1st host.
::1 Host.One

# 2nd host.
::2 Host.Two

# 1st host full duplicate.
::1 host.one

# 2nd host duplicate with new name.
::2 host.two Host.New

# 1st host with foreign name.
::1 host.new

# 2nd host new name.
::2 Again.Host.Two
07070100000033000081A4000000000000000000000001650C5921000003AC000000000000000000000000000000000000002800000000dnsproxy-0.55.0/internal/netutil/udp.gopackage netutil

import "net"

// UDPGetOOBSize returns the size of the buffer needed to receive the OOB data
// that UDPSetOptions enables on the current platform.
func UDPGetOOBSize() (oobSize int) {
	oobSize = udpGetOOBSize()

	return oobSize
}

// UDPSetOptions configures the UDP socket c so that the OOB data needed by
// UDPRead is received along with each message.
func UDPSetOptions(c *net.UDPConn) (err error) {
	err = udpSetOptions(c)

	return err
}

// UDPRead reads a message from conn into buf, along with up to udpOOBSize bytes
// of control-message payload.  It returns the number of bytes copied into buf,
// the local (destination) IP extracted from the payload, if any, and the source
// address of the message.
func UDPRead(
	conn *net.UDPConn,
	buf []byte,
	udpOOBSize int,
) (n int, localIP net.IP, remoteAddr *net.UDPAddr, err error) {
	n, localIP, remoteAddr, err = udpRead(conn, buf, udpOOBSize)

	return n, localIP, remoteAddr, err
}

// UDPWrite sends data to remoteAddr over conn, using localIP as the source
// address where the platform supports setting it.
func UDPWrite(
	data []byte,
	conn *net.UDPConn,
	remoteAddr *net.UDPAddr,
	localIP net.IP,
) (n int, err error) {
	n, err = udpWrite(data, conn, remoteAddr, localIP)

	return n, err
}
07070100000034000081A4000000000000000000000001650C592100000742000000000000000000000000000000000000002D00000000dnsproxy-0.55.0/internal/netutil/udp_unix.go//go:build unix

package netutil

import (
	"fmt"
	"net"

	"github.com/AdguardTeam/golibs/mathutil"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

// ipv4Flags are the control-message flags enabled on IPv4 UDP connections so
// that the received OOB data carries the destination address and the
// interface.
const ipv4Flags ipv4.ControlFlags = ipv4.FlagDst | ipv4.FlagInterface

// ipv6Flags are the IPv6 counterpart of ipv4Flags.
const ipv6Flags ipv6.ControlFlags = ipv6.FlagDst | ipv6.FlagInterface

// udpGetOOBSize returns the size of the OOB buffer that can hold the control
// message for either ipv4Flags or ipv6Flags, whichever is larger.
func udpGetOOBSize() (oobSize int) {
	l4 := len(ipv4.NewControlMessage(ipv4Flags))
	l6 := len(ipv6.NewControlMessage(ipv6Flags))

	return mathutil.Max(l4, l6)
}

// udpSetOptions enables the ipv6Flags and ipv4Flags control messages on c.  It
// only returns an error if neither could be enabled; presumably one family may
// legitimately fail when c is bound to a single address family.
func udpSetOptions(c *net.UDPConn) (err error) {
	err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6Flags, true)
	err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4Flags, true)
	if err4 == nil || err6 == nil {
		return nil
	}

	return fmt.Errorf("failed to call SetControlMessage: ipv4: %v; ipv6: %v", err4, err6)
}

// udpGetDstFromOOB returns the destination address carried in oob.  It tries
// to parse oob as an IPv6 control message first and as an IPv4 one second.  It
// returns nil if neither yields a destination.
func udpGetDstFromOOB(oob []byte) (dst net.IP) {
	if cm := (&ipv6.ControlMessage{}); cm.Parse(oob) == nil && cm.Dst != nil {
		return cm.Dst
	}

	if cm := (&ipv4.ControlMessage{}); cm.Parse(oob) == nil && cm.Dst != nil {
		return cm.Dst
	}

	return nil
}

// udpRead reads a message from c into buf together with up to udpOOBSize bytes
// of OOB data, and extracts the local address from that data.  On error, n is
// -1 and the other results are nil.
func udpRead(
	c *net.UDPConn,
	buf []byte,
	udpOOBSize int,
) (n int, localIP net.IP, remoteAddr *net.UDPAddr, err error) {
	oob := make([]byte, udpOOBSize)

	var oobLen int
	if n, oobLen, _, remoteAddr, err = c.ReadMsgUDP(buf, oob); err != nil {
		return -1, nil, nil, err
	}

	return n, udpGetDstFromOOB(oob[:oobLen]), remoteAddr, nil
}

// udpWrite sends data to remoteAddr over conn.  It passes OOB data built by
// udpMakeOOBWithSrc, which carries localIP as the source address where that is
// supported.
func udpWrite(
	data []byte,
	conn *net.UDPConn,
	remoteAddr *net.UDPAddr,
	localIP net.IP,
) (n int, err error) {
	oob := udpMakeOOBWithSrc(localIP)
	n, _, err = conn.WriteMsgUDP(data, oob, remoteAddr)

	return n, err
}
07070100000035000081A4000000000000000000000001650C59210000020B000000000000000000000000000000000000003000000000dnsproxy-0.55.0/internal/netutil/udp_windows.go//go:build windows

package netutil

import (
	"net"
)

// udpGetOOBSize returns zero on Windows, where no OOB data is read.
func udpGetOOBSize() (size int) {
	return size
}

// udpSetOptions does nothing on Windows and always succeeds.
func udpSetOptions(_ *net.UDPConn) (err error) {
	return err
}

// udpRead reads a message from c into buf.  Windows provides no local address,
// so the returned IP is always nil, and the OOB size is ignored.
func udpRead(c *net.UDPConn, buf []byte, _ int) (int, net.IP, *net.UDPAddr, error) {
	// ReadFromUDP returns a typed *net.UDPAddr, which is nil when ReadFrom
	// would return a nil net.Addr.
	n, remoteAddr, err := c.ReadFromUDP(buf)

	return n, nil, remoteAddr, err
}

// udpWrite sends msg to remoteAddr over conn.  Windows doesn't support setting
// the source address, so the local IP is ignored.
func udpWrite(msg []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, _ net.IP) (n int, err error) {
	n, err = conn.WriteTo(msg, remoteAddr)

	return n, err
}
07070100000036000081A4000000000000000000000001650C59210000026E000000000000000000000000000000000000003200000000dnsproxy-0.55.0/internal/netutil/udpoob_darwin.go//go:build darwin

package netutil

import (
	"net"

	"golang.org/x/net/ipv6"
)

// udpMakeOOBWithSrc returns the OOB data that sets ip as the source address.
// On darwin, only IPv6 addresses are set this way; for IPv4 ones it returns an
// empty, non-nil slice.
func udpMakeOOBWithSrc(ip net.IP) (b []byte) {
	if ip.To4() == nil {
		cm := &ipv6.ControlMessage{Src: ip}

		return cm.Marshal()
	}

	// Do not set the IPv4 source address via OOB, because it can cause the
	// address to become unspecified on darwin.
	//
	// See https://github.com/AdguardTeam/AdGuardHome/issues/2807.
	//
	// TODO(e.burkov): Develop a workaround to make it write OOB only when
	// listening on an unspecified address.
	return []byte{}
}
07070100000037000081A4000000000000000000000001650C59210000017B000000000000000000000000000000000000003200000000dnsproxy-0.55.0/internal/netutil/udpoob_others.go//go:build !darwin

package netutil

import (
	"net"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

// udpMakeOOBWithSrc makes the OOB data with the specified source IP.
func udpMakeOOBWithSrc(ip net.IP) (b []byte) {
	if ip4 := ip.To4(); ip4 != nil {
		return (&ipv4.ControlMessage{
			Src: ip,
		}).Marshal()
	}

	return (&ipv6.ControlMessage{
		Src: ip,
	}).Marshal()
}
07070100000038000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001F00000000dnsproxy-0.55.0/internal/tools07070100000039000081A4000000000000000000000001650C5921000000E5000000000000000000000000000000000000002600000000dnsproxy-0.55.0/internal/tools/doc.go// Package tools and its main module are a nested internal module containing our
// development tool dependencies.
//
// See https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module.
package tools
0707010000003A000081A4000000000000000000000001650C5921000004B3000000000000000000000000000000000000002600000000dnsproxy-0.55.0/internal/tools/go.modmodule github.com/AdguardTeam/dnsproxy/internal/tools

go 1.20

require (
	github.com/fzipp/gocyclo v0.6.0
	github.com/golangci/misspell v0.4.1
	github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601
	github.com/kisielk/errcheck v1.6.3
	github.com/kyoh86/looppointer v0.2.1
	github.com/securego/gosec/v2 v2.16.0
	github.com/uudashr/gocognit v1.0.7
	golang.org/x/tools v0.12.0
	golang.org/x/vuln v1.0.0
	honnef.co/go/tools v0.4.3
	mvdan.cc/gofumpt v0.5.0
	mvdan.cc/unparam v0.0.0-20230815095028-f7c6fb1088f0
)

require (
	github.com/BurntSushi/toml v1.3.2 // indirect
	github.com/google/go-cmp v0.5.9 // indirect
	github.com/google/uuid v1.3.0 // indirect
	github.com/gookit/color v1.5.4 // indirect
	github.com/kyoh86/nolint v0.0.1 // indirect
	github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect
	github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
	golang.org/x/exp v0.0.0-20230307190834-24139beb5833 // indirect
	golang.org/x/exp/typeparams v0.0.0-20230811145659-89c5cff77bcb // indirect
	golang.org/x/mod v0.12.0 // indirect
	golang.org/x/sync v0.3.0 // indirect
	golang.org/x/sys v0.11.0 // indirect
	gopkg.in/yaml.v3 v3.0.1 // indirect
)
0707010000003B000081A4000000000000000000000001650C5921000029FE000000000000000000000000000000000000002600000000dnsproxy-0.55.0/internal/tools/go.sumgithub.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8=
github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo=
github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/golangci/misspell v0.4.1 h1:+y73iSicVy2PqyX7kmUefHusENlrP9YwuHZHPLGQj/g=
github.com/golangci/misspell v0.4.1/go.mod h1:9mAN1quEo3DlpbaIKKyEvRxK1pwqR9s/Sea1bJCtlNI=
github.com/google/go-cmdtest v0.4.1-0.20220921163831-55ab3332a786 h1:rcv+Ippz6RAtvaGgKxc+8FQIpxHgsF+HBzPyYL2cyVU=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601 h1:mrEEilTAUmaAORhssPPkxj84TsHrPMLBGW2Z4SoTxm8=
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
github.com/kisielk/errcheck v1.6.3 h1:dEKh+GLHcWm2oN34nMvDzn1sqI0i0WxPvrgiJA5JuM8=
github.com/kisielk/errcheck v1.6.3/go.mod h1:nXw/i/MfnvRHqXa7XXmQMUB0oNFGuBrNI8d8NLy0LPw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kyoh86/looppointer v0.2.1 h1:Jx9fnkBj/JrIryBLMTYNTj9rvc2SrPS98Dg0w7fxdJg=
github.com/kyoh86/looppointer v0.2.1/go.mod h1:q358WcM8cMWU+5vzqukvaZtnJi1kw/MpRHQm3xvTrjw=
github.com/kyoh86/nolint v0.0.1 h1:GjNxDEkVn2wAxKHtP7iNTrRxytRZ1wXxLV5j4XzGfRU=
github.com/kyoh86/nolint v0.0.1/go.mod h1:1ZiZZ7qqrZ9dZegU96phwVcdQOMKIqRzFJL3ewq9gtI=
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 h1:4kuARK6Y6FxaNu/BnU2OAaLF86eTVhP2hjTB6iMvItA=
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354/go.mod h1:KSVJerMDfblTH7p5MZaTt+8zaT2iEk3AkVb9PQdZuE8=
github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/securego/gosec/v2 v2.16.0 h1:Pi0JKoasQQ3NnoRao/ww/N/XdynIB9NRYYZT5CyOs5U=
github.com/securego/gosec/v2 v2.16.0/go.mod h1:xvLcVZqUfo4aAQu56TNv7/Ltz6emAOQAEsrZrt7uGlI=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/uudashr/gocognit v1.0.7 h1:e9aFXgKgUJrQ5+bs61zBigmj7bFJ/5cC6HmMahVzuDo=
github.com/uudashr/gocognit v1.0.7/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKMIXrEpFxNMs0spT3/5s=
golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp/typeparams v0.0.0-20230811145659-89c5cff77bcb h1:v3JOchFBzuOEFQgVl0t5JnLg3yx29q2e1IjrEovxAt4=
golang.org/x/exp/typeparams v0.0.0-20230811145659-89c5cff77bcb/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220702020025-31831981b65f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4=
golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss=
golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM=
golang.org/x/vuln v1.0.0 h1:tYLAU3jD9LQr98Y+3el06lWyGMCnvzw06PIWP3LIy7g=
golang.org/x/vuln v1.0.0/go.mod h1:V0eyhHwaAaHrt42J9bgrN6rd12f6GU4T0Lu0ex2wDg4=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.4.3 h1:o/n5/K5gXqk8Gozvs2cnL0F2S1/g1vcGCAx2vETjITw=
honnef.co/go/tools v0.4.3/go.mod h1:36ZgoUOrqOk1GxwHhyryEkq8FQWkUO2xGuSMhUCcdvA=
mvdan.cc/gofumpt v0.5.0 h1:0EQ+Z56k8tXjj/6TQD25BFNKQXpCvT0rnansIc7Ug5E=
mvdan.cc/gofumpt v0.5.0/go.mod h1:HBeVDtMKRZpXyxFciAirzdKklDlGu8aAy1wEbH5Y9js=
mvdan.cc/unparam v0.0.0-20230815095028-f7c6fb1088f0 h1:NAENkqZ+Xofhqs4R4Af+i3HpZj1M23SFn/lHfRh1D4E=
mvdan.cc/unparam v0.0.0-20230815095028-f7c6fb1088f0/go.mod h1:flQN1deud3vIpPdF88533Lpp/MvzGLgPIPjB1kgBf4I=
0707010000003C000081A4000000000000000000000001650C592100000242000000000000000000000000000000000000002800000000dnsproxy-0.55.0/internal/tools/tools.go//go:build tools

package tools

import (
	_ "github.com/fzipp/gocyclo/cmd/gocyclo"
	_ "github.com/golangci/misspell/cmd/misspell"
	_ "github.com/gordonklaus/ineffassign"
	_ "github.com/kisielk/errcheck"
	_ "github.com/kyoh86/looppointer"
	_ "github.com/securego/gosec/v2/cmd/gosec"
	_ "github.com/uudashr/gocognit/cmd/gocognit"
	_ "golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness"
	_ "golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow"
	_ "golang.org/x/vuln/cmd/govulncheck"
	_ "honnef.co/go/tools/cmd/staticcheck"
	_ "mvdan.cc/gofumpt"
	_ "mvdan.cc/unparam"
)
0707010000003D000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000002100000000dnsproxy-0.55.0/internal/version0707010000003E000081A4000000000000000000000001650C592100000351000000000000000000000000000000000000002C00000000dnsproxy-0.55.0/internal/version/version.go// Package version contains dnsproxy version information.
package version

// Build information.
//
// These values are set by the linker.  Go cannot set constants during linking
// and has no notion of immutable variables, so to keep them read-only for
// other packages they stay unexported and are only reachable through the
// getters below.
var (
	branch     string
	committime string
	revision   string
	version    string
)

// Branch returns the Git branch compiled into the binary.
func Branch() (b string) {
	b = branch

	return b
}

// CommitTime returns the commit time compiled into the binary, as a string.
func CommitTime() (t string) {
	t = committime

	return t
}

// Revision returns the Git revision compiled into the binary.
func Revision() (r string) {
	r = revision

	return r
}

// Version returns the build version compiled into the binary, as a string.
func Version() (v string) {
	v = version

	return v
}
0707010000003F000081A4000000000000000000000001650C592100005EC6000000000000000000000000000000000000001800000000dnsproxy-0.55.0/main.go// Package main is responsible for command-line interface of dnsproxy.
package main

import (
	"crypto/tls"
	"fmt"
	"net"
	"net/http"
	"net/http/pprof"
	"net/netip"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/version"
	"github.com/AdguardTeam/dnsproxy/proxy"
	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/mathutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/timeutil"
	"github.com/ameshkov/dnscrypt/v2"
	goFlags "github.com/jessevdk/go-flags"
	"gopkg.in/yaml.v3"
)

// Options represents console arguments.  For further additions, please do not
// use the default option since it will cause some problems when config files
// are used.
//
// Most fields carry both a yaml tag, used when the options are read from the
// configuration file, and go-flags tags, used when they are parsed from the
// command line.  The command line is parsed after the file is loaded, so its
// values override the ones from the file.
//
// NOTE(review): Timeout and HTTPSServerName do have go-flags defaults, so a
// value set only in the configuration file may be replaced by that default
// during command-line parsing, as warned above.  Confirm before relying on it.
//
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
type Options struct {
	// ConfigPath is the path to the YAML configuration file.  The config path
	// should be read without using goFlags in order not to have default
	// values overriding yaml options.
	ConfigPath string `long:"config-path" description:"yaml configuration file. Minimal working configuration in config.yaml.dist. Options passed through command line will override the ones from this file." default:""`

	// Log settings
	// --

	// Verbose enables debug-level logging.
	Verbose bool `yaml:"verbose" short:"v" long:"verbose" description:"Verbose output (optional)" optional:"yes" optional-value:"true"`

	// LogOutput is the path to the log file.  If empty, logs go to stdout.
	LogOutput string `yaml:"output" short:"o" long:"output" description:"Path to the log file. If not set, write to stdout."`

	// Listen addrs
	// --

	// ListenAddrs are the server listen addresses.
	ListenAddrs []string `yaml:"listen-addrs" short:"l" long:"listen" description:"Listening addresses"`

	// ListenPorts are the plain DNS (UDP and TCP) listen ports.
	ListenPorts []int `yaml:"listen-ports" short:"p" long:"port" description:"Listening ports. Zero value disables TCP and UDP listeners"`

	// HTTPSListenPorts are the DNS-over-HTTPS listen ports.
	HTTPSListenPorts []int `yaml:"https-port" short:"s" long:"https-port" description:"Listening ports for DNS-over-HTTPS"`

	// TLSListenPorts are the DNS-over-TLS listen ports.
	TLSListenPorts []int `yaml:"tls-port" short:"t" long:"tls-port" description:"Listening ports for DNS-over-TLS"`

	// QUICListenPorts are the DNS-over-QUIC listen ports.
	QUICListenPorts []int `yaml:"quic-port" short:"q" long:"quic-port" description:"Listening ports for DNS-over-QUIC"`

	// DNSCryptListenPorts are the DNSCrypt listen ports.
	DNSCryptListenPorts []int `yaml:"dnscrypt-port" short:"y" long:"dnscrypt-port" description:"Listening ports for DNSCrypt"`

	// Encryption config
	// --

	// TLSCertPath is the path to the .crt file with the certificate chain.
	TLSCertPath string `yaml:"tls-crt" short:"c" long:"tls-crt" description:"Path to a file with the certificate chain"`

	// TLSKeyPath is the path to the file with the private key.
	TLSKeyPath string `yaml:"tls-key" short:"k" long:"tls-key" description:"Path to a file with the private key"`

	// TLSMinVersion is the minimum TLS version, for example 1.2.
	TLSMinVersion float32 `yaml:"tls-min-version" long:"tls-min-version" description:"Minimum TLS version, for example 1.0" optional:"yes"`

	// TLSMaxVersion is the maximum TLS version, for example 1.3.
	TLSMaxVersion float32 `yaml:"tls-max-version" long:"tls-max-version" description:"Maximum TLS version, for example 1.3" optional:"yes"`

	// Insecure disables TLS certificate verification of the upstreams.
	Insecure bool `yaml:"insecure" long:"insecure" description:"Disable secure TLS certificate validation" optional:"yes" optional-value:"false"`

	// DNSCryptConfigPath is the path to the DNSCrypt configuration file.
	DNSCryptConfigPath string `yaml:"dnscrypt-config" short:"g" long:"dnscrypt-config" description:"Path to a file with DNSCrypt configuration. You can generate one using https://github.com/ameshkov/dnscrypt"`

	// HTTP3 controls whether HTTP/3 is enabled for this instance of dnsproxy.
	// It enables HTTP/3 support for both the DoH upstreams and the DoH server.
	HTTP3 bool `yaml:"http3" long:"http3" description:"Enable HTTP/3 support" optional:"yes" optional-value:"false"`

	// Upstream DNS servers settings
	// --

	// Upstreams are the DNS upstreams or paths to files listing them.
	Upstreams []string `yaml:"upstream" short:"u" long:"upstream" description:"An upstream to be used (can be specified multiple times). You can also specify path to a file with the list of servers" optional:"false"`

	// BootstrapDNS are the bootstrap DNS servers.
	BootstrapDNS []string `yaml:"bootstrap" short:"b" long:"bootstrap" description:"Bootstrap DNS for DoH and DoT, can be specified multiple times (default: use system-provided)"`

	// Fallbacks are the fallback DNS resolvers or paths to files listing them.
	Fallbacks []string `yaml:"fallback" short:"f" long:"fallback" description:"Fallback resolvers to use when regular ones are unavailable, can be specified multiple times. You can also specify path to a file with the list of servers"`

	// PrivateRDNSUpstreams are upstreams to use for reverse DNS lookups of
	// private addresses.
	PrivateRDNSUpstreams []string `yaml:"private-rdns-upstream" long:"private-rdns-upstream" description:"Private DNS upstreams to use for reverse DNS lookups of private addresses, can be specified multiple times"`

	// AllServers, if true, enables parallel queries to all configured
	// upstream servers.
	AllServers bool `yaml:"all-servers" long:"all-servers" description:"If specified, parallel queries to all configured upstream servers are enabled" optional:"yes" optional-value:"true"`

	// FastestAddress, if true, makes dnsproxy respond to A or AAAA requests
	// only with the fastest IP address, detected by ICMP response time or TCP
	// connection time.
	FastestAddress bool `yaml:"fastest-addr" long:"fastest-addr" description:"Respond to A or AAAA requests only with the fastest IP address" optional:"yes" optional-value:"true"`

	// Timeout for outbound DNS queries to remote upstream servers in a
	// human-readable form.  Default is 10s.
	Timeout timeutil.Duration `yaml:"timeout" long:"timeout" description:"Timeout for outbound DNS queries to remote upstream servers in a human-readable form" default:"10s"`

	// Cache settings
	// --

	// Cache, if true, enables the DNS cache.
	Cache bool `yaml:"cache" long:"cache" description:"If specified, DNS cache is enabled" optional:"yes" optional-value:"true"`

	// CacheSizeBytes is the cache size, in bytes.
	CacheSizeBytes int `yaml:"cache-size" long:"cache-size" description:"Cache size (in bytes). Default: 64k"`

	// CacheMinTTL is the minimum TTL for cached DNS entries; it overrides the
	// record value.
	CacheMinTTL uint32 `yaml:"cache-min-ttl" long:"cache-min-ttl" description:"Minimum TTL value for DNS entries, in seconds. Capped at 3600. Artificially extending TTLs should only be done with careful consideration."`

	// CacheMaxTTL is the maximum TTL for cached DNS entries; it overrides the
	// record value.
	CacheMaxTTL uint32 `yaml:"cache-max-ttl" long:"cache-max-ttl" description:"Maximum TTL value for DNS entries, in seconds."`

	// CacheOptimistic, if set to true, enables the optimistic DNS cache.  That
	// means that cached results will be served even if their cache TTL has
	// already expired.
	CacheOptimistic bool `yaml:"cache-optimistic" long:"cache-optimistic" description:"If specified, optimistic DNS cache is enabled" optional:"yes" optional-value:"true"`

	// Anti-DNS amplification measures
	// --

	// Ratelimit is the rate limit, in requests per second.
	Ratelimit int `yaml:"ratelimit" short:"r" long:"ratelimit" description:"Ratelimit (requests per second)"`

	// RefuseAny, if true, makes dnsproxy refuse ANY requests.
	RefuseAny bool `yaml:"refuse-any" long:"refuse-any" description:"If specified, refuse ANY requests" optional:"yes" optional-value:"true"`

	// ECS settings
	// --

	// EnableEDNSSubnet enables the EDNS Client Subnet extension.
	EnableEDNSSubnet bool `yaml:"edns" long:"edns" description:"Use EDNS Client Subnet extension" optional:"yes" optional-value:"true"`

	// EDNSAddr is a custom EDNS Client Subnet address.  It only has effect
	// together with EnableEDNSSubnet.
	EDNSAddr string `yaml:"edns-addr" long:"edns-addr" description:"Send EDNS Client Address"`

	// DNS64 settings
	// --

	// DNS64 defines whether DNS64 functionality is enabled or not.
	DNS64 bool `yaml:"dns64" long:"dns64" description:"If specified, dnsproxy will act as a DNS64 server" optional:"yes" optional-value:"true"`

	// DNS64Prefix defines the DNS64 prefixes that dnsproxy should use when it
	// acts as a DNS64 server.  If not specified, dnsproxy uses the default
	// Well-Known Prefix.  This option can be specified multiple times.
	DNS64Prefix []string `yaml:"dns64-prefix" long:"dns64-prefix" description:"Prefix used to handle DNS64. If not specified, dnsproxy uses the 'Well-Known Prefix' 64:ff9b::.  Can be specified multiple times" required:"false"`

	// Other settings and options
	// --

	// HTTPSServerName is the value of the Server header in the responses of
	// the HTTPS server.
	HTTPSServerName string `yaml:"https-server-name" long:"https-server-name" description:"Set the Server header for the responses from the HTTPS server." default:"dnsproxy"`

	// IPv6Disabled, if true, makes all AAAA requests be replied with the
	// NoError RCode and an empty answer.
	IPv6Disabled bool `yaml:"ipv6-disabled" long:"ipv6-disabled" description:"If specified, all AAAA requests will be replied with NoError RCode and empty answer" optional:"yes" optional-value:"true"`

	// BogusNXDomain are addresses and CIDRs; responses that contain at least
	// one matching IP address are transformed into NXDOMAIN.
	BogusNXDomain []string `yaml:"bogus-nxdomain" long:"bogus-nxdomain" description:"Transform the responses containing at least a single IP that matches specified addresses and CIDRs into NXDOMAIN.  Can be specified multiple times."`

	// UDPBufferSize is the UDP buffer size, in bytes.
	UDPBufferSize int `yaml:"udp-buf-size" long:"udp-buf-size" description:"Set the size of the UDP buffer in bytes. A value <= 0 will use the system default."`

	// MaxGoRoutines is the maximum number of goroutines.
	MaxGoRoutines int `yaml:"max-go-routines" long:"max-go-routines" description:"Set the maximum number of go routines. A value <= 0 will not set a maximum."`

	// Pprof defines whether the pprof information needs to be exposed via
	// localhost:6060 or not.
	Pprof bool `yaml:"pprof" long:"pprof" description:"If present, exposes pprof information on localhost:6060." optional:"yes" optional-value:"true"`

	// Version prints the dnsproxy version.  It is only declared here for the
	// help output; the flag itself is handled in main before parsing.
	Version bool `yaml:"version" long:"version" description:"Prints the program version"`
}

// defaultLocalTimeout is the upper bound of the timeout used for the private
// reverse DNS upstreams.
const defaultLocalTimeout = time.Second

func main() {
	options := &Options{}

	for _, arg := range os.Args {
		if arg == "--version" {
			fmt.Printf("dnsproxy version: %s\n", version.Version())

			os.Exit(0)
		}

		// TODO(e.burkov, a.garipov):  Use flag package and remove the manual
		// options parsing.
		//
		// See https://github.com/AdguardTeam/dnsproxy/issues/182.
		if len(arg) > 13 {
			if arg[:13] == "--config-path" {
				fmt.Printf("Path: %s\n", arg[14:])
				b, err := os.ReadFile(arg[14:])
				if err != nil {
					log.Fatalf("failed to read the config file %s: %v", arg[14:], err)
				}
				err = yaml.Unmarshal(b, options)
				if err != nil {
					log.Fatalf("failed to unmarshal the config file %s: %v", arg[14:], err)
				}
			}
		}
	}

	parser := goFlags.NewParser(options, goFlags.Default)
	_, err := parser.Parse()
	if err != nil {
		if flagsErr, ok := err.(*goFlags.Error); ok && flagsErr.Type == goFlags.ErrHelp {
			os.Exit(0)
		}

		os.Exit(1)
	}

	run(options)
}

// run sets up logging and the optional pprof server, starts the DNS proxy
// built from options, and blocks until SIGINT or SIGTERM arrives, after which
// it stops the proxy.  Failing to open the log file, start, or stop the proxy
// is fatal.
func run(options *Options) {
	if options.Verbose {
		log.SetLevel(log.DEBUG)
	}

	if logPath := options.LogOutput; logPath != "" {
		// #nosec G302 -- Trust the file path that is given in the
		// configuration.
		logFile, openErr := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
		if openErr != nil {
			log.Fatalf("cannot create a log file: %s", openErr)
		}

		defer func() { _ = logFile.Close() }()
		log.SetOutput(logFile)
	}

	runPprof(options)

	log.Info("Starting dnsproxy %s", version.Version())

	// Build the proxy server from its configuration.
	dnsProxy := &proxy.Proxy{Config: createProxyConfig(options)}

	// Install the AAAA-filtering handler only when it's needed.
	if options.IPv6Disabled {
		ipv6Conf := &ipv6Configuration{ipv6Disabled: options.IPv6Disabled}
		dnsProxy.RequestHandler = ipv6Conf.handleDNSRequest
	}

	if startErr := dnsProxy.Start(); startErr != nil {
		log.Fatalf("cannot start the DNS proxy due to %s", startErr)
	}

	// Wait for a termination signal.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh

	if stopErr := dnsProxy.Stop(); stopErr != nil {
		log.Fatalf("cannot stop the DNS proxy due to %s", stopErr)
	}
}

// runPprof runs pprof server on localhost:6060 if it's enabled in the options.
func runPprof(options *Options) {
	if !options.Pprof {
		return
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/debug/pprof/", pprof.Index)
	mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
	mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
	mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
	mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
	mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
	mux.Handle("/debug/pprof/block", pprof.Handler("block"))
	mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
	mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
	mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
	mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))

	go func() {
		log.Info("pprof: listening on localhost:6060")
		srv := &http.Server{
			Addr:        "localhost:6060",
			ReadTimeout: 60 * time.Second,
			Handler:     mux,
		}
		err := srv.ListenAndServe()
		log.Error("error while running the pprof server: %s", err)
	}()
}

// createProxyConfig builds a proxy.Config from the parsed command-line and
// file options.  The init* helpers run in a fixed order because later ones
// read what earlier ones set: initListenAddrs needs the TLS and DNSCrypt
// settings, and initDNS64 needs the private upstreams.
func createProxyConfig(options *Options) proxy.Config {
	conf := proxy.Config{
		Ratelimit:       options.Ratelimit,
		CacheEnabled:    options.Cache,
		CacheSizeBytes:  options.CacheSizeBytes,
		CacheMinTTL:     options.CacheMinTTL,
		CacheMaxTTL:     options.CacheMaxTTL,
		CacheOptimistic: options.CacheOptimistic,
		RefuseAny:       options.RefuseAny,
		HTTP3:           options.HTTP3,
		// TODO(e.burkov):  The following CIDRs are aimed to match any
		// address.  This is not quite proper approach to be used by
		// default so think about configuring it.
		TrustedProxies:         []string{"0.0.0.0/0", "::0/0"},
		EnableEDNSClientSubnet: options.EnableEDNSSubnet,
		UDPBufferSize:          options.UDPBufferSize,
		HTTPSServerName:        options.HTTPSServerName,
		MaxGoroutines:          options.MaxGoRoutines,
	}

	// TODO(e.burkov):  Make these methods of [Options].
	initializers := []func(c *proxy.Config, o *Options){
		initUpstreams,
		initEDNS,
		initBogusNXDomain,
		initTLSConfig,
		initDNSCryptConfig,
		initListenAddrs,
		initDNS64,
	}
	for _, initFn := range initializers {
		initFn(&conf, options)
	}

	return conf
}

// isEmpty reports whether uc has no general, domain-reserved, or
// domain-specified upstreams at all.  uc must not be nil.
//
// TODO(e.burkov):  Think of a better way to validate the config.  Perhaps,
// return an error from [ParseUpstreamsConfig] if no upstreams were initialized.
func isEmpty(uc *proxy.UpstreamConfig) (ok bool) {
	hasAny := len(uc.Upstreams) > 0 ||
		len(uc.DomainReservedUpstreams) > 0 ||
		len(uc.SpecifiedDomainUpstreams) > 0

	return !hasAny
}

// initUpstreams parses the general, private reverse DNS, and fallback
// upstreams from options into config and selects the upstream mode.  Any
// parsing error is fatal.  The private and fallback configurations are only
// set when they contain at least one upstream.
func initUpstreams(config *proxy.Config, options *Options) {
	httpVersions := upstream.DefaultHTTPVersions
	if options.HTTP3 {
		// Add HTTP/3 alongside HTTP/2 and HTTP/1.1.
		httpVersions = []upstream.HTTPVersion{
			upstream.HTTPVersion3,
			upstream.HTTPVersion2,
			upstream.HTTPVersion11,
		}
	}

	timeout := options.Timeout.Duration
	upsOpts := &upstream.Options{
		HTTPVersions:       httpVersions,
		InsecureSkipVerify: options.Insecure,
		Bootstrap:          options.BootstrapDNS,
		Timeout:            timeout,
	}

	var err error
	config.UpstreamConfig, err = proxy.ParseUpstreamsConfig(loadServersList(options.Upstreams), upsOpts)
	if err != nil {
		log.Fatalf("error while parsing upstreams configuration: %s", err)
	}

	privUpsOpts := &upstream.Options{
		HTTPVersions: httpVersions,
		Bootstrap:    options.BootstrapDNS,
		// Never wait longer than defaultLocalTimeout for private upstreams.
		Timeout: mathutil.Min(defaultLocalTimeout, timeout),
	}

	private, err := proxy.ParseUpstreamsConfig(loadServersList(options.PrivateRDNSUpstreams), privUpsOpts)
	if err != nil {
		log.Fatalf("error while parsing private rdns upstreams configuration: %s", err)
	}

	if !isEmpty(private) {
		config.PrivateRDNSUpstreamConfig = private
	}

	fallbacks, err := proxy.ParseUpstreamsConfig(loadServersList(options.Fallbacks), upsOpts)
	if err != nil {
		log.Fatalf("error while parsing fallback upstreams configuration: %s", err)
	}

	if !isEmpty(fallbacks) {
		config.Fallbacks = fallbacks
	}

	switch {
	case options.AllServers:
		config.UpstreamMode = proxy.UModeParallel
	case options.FastestAddress:
		config.UpstreamMode = proxy.UModeFastestAddr
	default:
		config.UpstreamMode = proxy.UModeLoadBalance
	}
}

// initEDNS sets config.EDNSAddr from options.EDNSAddr.  The address is only
// applied when EDNS Client Subnet is enabled; otherwise a message is logged
// and the address is ignored.  An unparseable address is fatal.
func initEDNS(config *proxy.Config, options *Options) {
	addr := options.EDNSAddr
	if addr == "" {
		return
	}

	if !options.EnableEDNSSubnet {
		log.Printf("--edns-addr=%s need --edns to work", addr)

		return
	}

	ip := net.ParseIP(addr)
	if ip == nil {
		log.Fatalf("cannot parse %s", addr)
	}

	config.EDNSAddr = ip
}

// initBogusNXDomain appends the subnets parsed from options.BogusNXDomain to
// config.BogusNXDomain.  Entries that fail to parse are logged and skipped.
func initBogusNXDomain(config *proxy.Config, options *Options) {
	for _, entry := range options.BogusNXDomain {
		subnet, parseErr := netutil.ParseSubnet(entry)
		if parseErr == nil {
			config.BogusNXDomain = append(config.BogusNXDomain, subnet)
		} else {
			log.Error("%s", parseErr)
		}
	}
}

// initTLSConfig sets config.TLSConfig when both the certificate and the key
// paths are configured.  A failure to build the TLS configuration is fatal.
func initTLSConfig(config *proxy.Config, options *Options) {
	if options.TLSCertPath == "" || options.TLSKeyPath == "" {
		return
	}

	tlsConf, err := newTLSConfig(options)
	if err != nil {
		log.Fatalf("failed to load TLS config: %s", err)
	}

	config.TLSConfig = tlsConf
}

// initDNSCryptConfig reads the DNSCrypt resolver configuration from the YAML
// file at options.DNSCryptConfigPath, creates its certificate, and stores the
// certificate and provider name in config.  It does nothing when the path is
// empty.  Any read, parse, or certificate error is fatal.
func initDNSCryptConfig(config *proxy.Config, options *Options) {
	confPath := options.DNSCryptConfigPath
	if confPath == "" {
		return
	}

	data, err := os.ReadFile(confPath)
	if err != nil {
		log.Fatalf("failed to read DNSCrypt config %s: %v", confPath, err)
	}

	rc := &dnscrypt.ResolverConfig{}
	if err = yaml.Unmarshal(data, rc); err != nil {
		log.Fatalf("failed to unmarshal DNSCrypt config: %v", err)
	}

	cert, err := rc.CreateCert()
	if err != nil {
		log.Fatalf("failed to create DNSCrypt certificate: %v", err)
	}

	config.DNSCryptResolverCert = cert
	config.DNSCryptProviderName = rc.ProviderName
}

// initListenAddrs fills the listen addresses of config from options.
//
// If no addresses or ports are configured, it defaults them to "0.0.0.0" and
// 53 respectively and writes the defaults back into options.  Listener
// requirements:
//   - Plain DNS (UDP and TCP) listeners are skipped when the first port is 0.
//   - DoT, DoH, and DoQ listeners require config.TLSConfig.
//   - DNSCrypt listeners require the DNSCrypt certificate and provider name.
//
// An unparseable address is fatal.
func initListenAddrs(config *proxy.Config, options *Options) {
	if len(options.ListenAddrs) == 0 {
		// Neither the config file nor the command line set any addresses.
		options.ListenAddrs = []string{"0.0.0.0"}
	}

	if len(options.ListenPorts) == 0 {
		// Neither the config file nor the command line set any ports.
		options.ListenPorts = []int{53}
	}

	ips := make([]net.IP, 0, len(options.ListenAddrs))
	for _, addr := range options.ListenAddrs {
		ip := net.ParseIP(addr)
		if ip == nil {
			log.Fatalf("cannot parse %s", addr)
		}

		ips = append(ips, ip)
	}

	// ListenPorts is never empty here; a zero first port disables the plain
	// DNS listeners.
	if options.ListenPorts[0] != 0 {
		for _, port := range options.ListenPorts {
			for _, ip := range ips {
				config.UDPListenAddr = append(config.UDPListenAddr, &net.UDPAddr{IP: ip, Port: port})
				config.TCPListenAddr = append(config.TCPListenAddr, &net.TCPAddr{IP: ip, Port: port})
			}
		}
	}

	if config.TLSConfig != nil {
		for _, port := range options.TLSListenPorts {
			for _, ip := range ips {
				config.TLSListenAddr = append(config.TLSListenAddr, &net.TCPAddr{IP: ip, Port: port})
			}
		}

		for _, port := range options.HTTPSListenPorts {
			for _, ip := range ips {
				config.HTTPSListenAddr = append(config.HTTPSListenAddr, &net.TCPAddr{IP: ip, Port: port})
			}
		}

		for _, port := range options.QUICListenPorts {
			for _, ip := range ips {
				config.QUICListenAddr = append(config.QUICListenAddr, &net.UDPAddr{IP: ip, Port: port})
			}
		}
	}

	if config.DNSCryptResolverCert == nil || config.DNSCryptProviderName == "" {
		return
	}

	for _, port := range options.DNSCryptListenPorts {
		for _, ip := range ips {
			config.DNSCryptTCPListenAddr = append(config.DNSCryptTCPListenAddr, &net.TCPAddr{IP: ip, Port: port})
			config.DNSCryptUDPListenAddr = append(config.DNSCryptUDPListenAddr, &net.UDPAddr{IP: ip, Port: port})
		}
	}
}

// initDNS64 copies the DNS64 settings from options into conf.  When DNS64 is
// enabled, at least one private reverse DNS upstream must be configured.  A
// missing private upstream or an unparseable prefix is fatal.
func initDNS64(conf *proxy.Config, options *Options) {
	conf.UseDNS64 = options.DNS64
	if !conf.UseDNS64 {
		return
	}

	privConf := conf.PrivateRDNSUpstreamConfig
	if privConf == nil || isEmpty(privConf) {
		log.Fatalf("at least one private upstream must be configured to use dns64")
	}

	// Keep prefs nil when no prefixes are configured.
	var prefs []netip.Prefix
	for idx, raw := range options.DNS64Prefix {
		pref, parseErr := netip.ParsePrefix(raw)
		if parseErr != nil {
			log.Fatalf("parsing prefix at index %d: %v", idx, parseErr)
		}

		prefs = append(prefs, pref)
	}

	conf.DNS64Prefs = prefs
}

// ipv6Configuration holds the IPv6-related request handling settings.
type ipv6Configuration struct {
	// ipv6Disabled, if true, makes all AAAA requests be replied with the
	// NoError RCode and an empty answer.
	ipv6Disabled bool
}

// handleDNSRequest is a request handler that first lets
// proxy.CheckDisabledAAAARequest decide about the request using the IPv6
// configuration.  It only resolves the request via p when that check does not
// report it as handled.
func (c *ipv6Configuration) handleDNSRequest(p *proxy.Proxy, ctx *proxy.DNSContext) error {
	if !proxy.CheckDisabledAAAARequest(ctx, c.ipv6Disabled) {
		return p.Resolve(ctx)
	}

	return nil
}

// NewTLSConfig returns a TLS config that includes a certificate
// Use for server TLS config or when using a client certificate
// If caPath is empty, system CAs will be used
func newTLSConfig(options *Options) (*tls.Config, error) {
	// Set default TLS min/max versions
	tlsMinVersion := tls.VersionTLS10 // Default for crypto/tls
	tlsMaxVersion := tls.VersionTLS13 // Default for crypto/tls
	switch options.TLSMinVersion {
	case 1.1:
		tlsMinVersion = tls.VersionTLS11
	case 1.2:
		tlsMinVersion = tls.VersionTLS12
	case 1.3:
		tlsMinVersion = tls.VersionTLS13
	}
	switch options.TLSMaxVersion {
	case 1.0:
		tlsMaxVersion = tls.VersionTLS10
	case 1.1:
		tlsMaxVersion = tls.VersionTLS11
	case 1.2:
		tlsMaxVersion = tls.VersionTLS12
	}

	cert, err := loadX509KeyPair(options.TLSCertPath, options.TLSKeyPath)
	if err != nil {
		return nil, fmt.Errorf("could not load TLS cert: %s", err)
	}

	// #nosec G402 -- TLS MinVersion is configured by user.
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		MinVersion:   uint16(tlsMinVersion),
		MaxVersion:   uint16(tlsMaxVersion),
	}, nil
}

// loadX509KeyPair reads and parses a public/private key pair from a pair of
// files.  The files must contain PEM encoded data.  The certificate file may
// contain intermediate certificates following the leaf certificate to form a
// certificate chain.  On successful return, Certificate.Leaf will be nil
// because the parsed form of the certificate is not retained.
func loadX509KeyPair(certFile, keyFile string) (crt tls.Certificate, err error) {
	// #nosec G304 -- Trust the file path that is given in the configuration.
	certPEMBlock, err := os.ReadFile(certFile)
	if err != nil {
		return tls.Certificate{}, err
	}

	// #nosec G304 -- Trust the file path that is given in the configuration.
	keyPEMBlock, err := os.ReadFile(keyFile)
	if err != nil {
		return tls.Certificate{}, err
	}

	return tls.X509KeyPair(certPEMBlock, keyPEMBlock)
}

// loadServersList loads a list of DNS servers from sources.  Each source is
// either a server address or the path to a file with a list of addresses, one
// per line.  A source that can't be read as a file is used verbatim as a
// server address.  Lines read from files are trimmed of surrounding
// whitespace.  Empty lines and lines starting with "!" or "#" are skipped as
// comments.  It returns nil when sources yield no servers.
func loadServersList(sources []string) []string {
	var servers []string

	for _, source := range sources {
		// #nosec G304 -- Trust the file path that is given in the
		// configuration.
		data, err := os.ReadFile(source)
		if err != nil {
			// Ignore errors, just consider it a server address and not a file.
			servers = append(servers, source)

			// There is no file content to process.
			continue
		}

		for _, line := range strings.Split(string(data), "\n") {
			// TrimSpace also drops the "\r" of CRLF line endings.
			line = strings.TrimSpace(line)

			// Ignore empty lines and comments in the file.
			if line == "" ||
				strings.HasPrefix(line, "!") ||
				strings.HasPrefix(line, "#") {
				continue
			}

			servers = append(servers, line)
		}
	}

	return servers
}
07070100000040000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001600000000dnsproxy-0.55.0/proxy07070100000041000081A4000000000000000000000001650C59210000035F000000000000000000000000000000000000002700000000dnsproxy-0.55.0/proxy/bogusnxdomain.gopackage proxy

import (
	"net"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

// isBogusNXDomain reports whether m answers an A or AAAA question, judged by
// its first question, with at least one IP address in the Answer section that
// falls into one of the p.BogusNXDomain subnets.  It returns false for a nil
// m, a message without questions, or when no bogus subnets are configured.
func (p *Proxy) isBogusNXDomain(m *dns.Msg) (ok bool) {
	if m == nil || len(p.BogusNXDomain) == 0 || len(m.Question) == 0 {
		return false
	}

	switch m.Question[0].Qtype {
	case dns.TypeA, dns.TypeAAAA:
		// Only address responses can contain bogus IPs.
	default:
		return false
	}

	for _, rr := range m.Answer {
		if containsIP(p.BogusNXDomain, proxyutil.IPFromRR(rr)) {
			return true
		}
	}

	return false
}

// containsIP reports whether ip belongs to at least one of nets.  An ip that
// netutil.ValidateIP rejects, for example a nil one, is never contained.
func containsIP(nets []*net.IPNet, ip net.IP) (ok bool) {
	if err := netutil.ValidateIP(ip); err != nil {
		return false
	}

	for i := range nets {
		if ok = nets[i].Contains(ip); ok {
			break
		}
	}

	return ok
}
07070100000042000081A4000000000000000000000001650C592100000FDC000000000000000000000000000000000000002C00000000dnsproxy-0.55.0/proxy/bogusnxdomain_test.gopackage proxy

import (
	"net"
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestProxy_IsBogusNXDomain checks that the proxy answers with NXDOMAIN when
// the upstream response contains an address from one of the configured
// BogusNXDomain subnets, and keeps the success code otherwise.
func TestProxy_IsBogusNXDomain(t *testing.T) {
	prx := createTestProxy(t, nil)
	prx.CacheEnabled = true

	// The subnets are a /24 given as a 4-byte IPv4 address, a /8 given as a
	// 16-byte IPv4 address, a single IPv4 address, and an IPv6 /120.
	prx.BogusNXDomain = []*net.IPNet{{
		IP:   net.IP{4, 3, 2, 1},
		Mask: net.CIDRMask(24, netutil.IPv4BitLen),
	}, {
		IP:   net.IPv4(1, 2, 3, 4),
		Mask: net.IPv4Mask(255, 0, 0, 0),
	}, {
		IP:   net.IP{10, 11, 12, 13},
		Mask: net.CIDRMask(netutil.IPv4BitLen, netutil.IPv4BitLen),
	}, {
		IP:   net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
		Mask: net.CIDRMask(120, netutil.IPv6BitLen),
	}}

	testCases := []struct {
		name      string
		ans       []dns.RR
		wantRcode int
	}{{
		name: "bogus_subnet",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("4.3.2.1"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "bogus_big_subnet",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("1.254.254.254"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "bogus_single_ip",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("10.11.12.13"),
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "bogus_6",
		ans: []dns.RR{&dns.AAAA{
			Hdr:  dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10},
			AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 99},
		}},
		wantRcode: dns.RcodeNameError,
	}, {
		name: "non-bogus",
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10},
			A:   net.ParseIP("10.11.12.14"),
		}},
		wantRcode: dns.RcodeSuccess,
	}, {
		name: "non-bogus_6",
		ans: []dns.RR{&dns.AAAA{
			Hdr:  dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10},
			AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 15},
		}},
		wantRcode: dns.RcodeSuccess,
	}}

	u := testUpstream{}
	prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u}

	err := prx.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, prx.Stop)

	// NOTE(review): d, and so the request, is shared by all the cases while
	// the cache is enabled, so a later case may be answered from the cache
	// instead of from u.  Confirm that this is intended.
	d := &DNSContext{
		Req: createHostTestMessage("host"),
	}

	for _, tc := range testCases {
		// The upstream answer is set outside of the subtest, which is fine
		// since the subtests don't call t.Parallel and run one by one.
		u.ans = tc.ans

		t.Run(tc.name, func(t *testing.T) {
			err = prx.Resolve(d)
			require.NoError(t, err)
			require.NotNil(t, d.Res)

			assert.Equal(t, tc.wantRcode, d.Res.Rcode)
		})
	}
}

// TestContainsIP checks containsIP against an IPv4, an IPv4-in-IPv6, and an
// IPv6 subnet, using addresses in both forms as well as nil and malformed
// addresses, which must never match.
func TestContainsIP(t *testing.T) {
	subnets := []*net.IPNet{
		// IPv4.
		{IP: net.IP{1, 2, 3, 0}, Mask: net.IPv4Mask(255, 255, 255, 0)},
		// IPv6 from IPv4.
		{IP: net.IPv4(1, 2, 4, 0), Mask: net.CIDRMask(16, 32)},
		// IPv6.
		{
			IP:   net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0},
			Mask: net.CIDRMask(120, net.IPv6len*8),
		},
	}

	testCases := []struct {
		name string
		want assert.BoolAssertionFunc
		ip   net.IP
	}{
		{name: "ipv4_yes", want: assert.True, ip: net.IP{1, 2, 3, 255}},
		{name: "ipv4_6_yes", want: assert.True, ip: net.IPv4(1, 2, 4, 254)},
		{
			name: "ipv6_yes",
			want: assert.True,
			ip:   net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
		},
		{
			name: "ipv6_4_yes",
			want: assert.True,
			ip:   net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 1, 2, 3, 0},
		},
		{name: "ipv4_no", want: assert.False, ip: net.IP{2, 1, 3, 255}},
		{name: "ipv4_6_no", want: assert.False, ip: net.IPv4(2, 1, 4, 254)},
		{
			name: "ipv6_no",
			want: assert.False,
			ip:   net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 15},
		},
		{
			name: "ipv6_4_no",
			want: assert.False,
			ip:   net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 2, 1, 4, 0},
		},
		{name: "nil_no", want: assert.False, ip: nil},
		{name: "bad_ip", want: assert.False, ip: net.IP{42}},
	}

	for _, c := range testCases {
		t.Run(c.name, func(t *testing.T) {
			got := containsIP(subnets, c.ip)
			c.want(t, got)
		})
	}
}
07070100000043000081A4000000000000000000000001650C592100003EC5000000000000000000000000000000000000001F00000000dnsproxy-0.55.0/proxy/cache.gopackage proxy

import (
	"bytes"
	"encoding/binary"
	"math"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	glcache "github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/mathutil"
	"github.com/miekg/dns"
	"golang.org/x/exp/slices"
)

// defaultCacheSize is the size of cache in bytes by default.  createCache uses
// it when the configured size is not positive.
const defaultCacheSize = 64 * 1024

// cache is used to cache requests and used upstreams.  It has two separate
// storages: one for plain lookups and one for lookups with EDNS Client Subnet,
// each guarded by its own lock.
type cache struct {
	// itemsLock protects items.
	itemsLock *sync.RWMutex

	// itemsWithSubnetLock protects itemsWithSubnet.
	itemsWithSubnetLock *sync.RWMutex

	// items is the requests cache keyed by the result of msgToKey.
	items glcache.Cache

	// itemsWithSubnet is the requests cache keyed by the result of
	// msgToKeyWithSubnet.  It's nil unless newCache was called with ECS
	// enabled.
	itemsWithSubnet glcache.Cache

	// optimistic defines if the cache should return expired items and resolve
	// those again.
	optimistic bool
}

// cacheItem is a single cache entry.  It's a helper type to aggregate the
// item-specific logic.
type cacheItem struct {
	// m contains the cached response.
	m *dns.Msg

	// u contains an address of the upstream which resolved m.  It's empty if
	// the upstream was unknown when the item was created.
	u string

	// ttl is the time-to-live value for the item in seconds.  Should be set
	// before calling [cacheItem.pack].
	ttl uint32
}

// respToItem wraps the response m and the upstream u that resolved it into a
// cache item with the TTL computed by cacheTTL.  It returns nil if that TTL is
// zero, meaning m must not be cached.  u may be nil, in which case the item's
// upstream address is empty.
func respToItem(m *dns.Msg, u upstream.Upstream) (item *cacheItem) {
	item = &cacheItem{m: m, ttl: cacheTTL(m)}
	if item.ttl == 0 {
		return nil
	}

	if u != nil {
		item.u = u.Address()
	}

	return item
}

const (
	// packedMsgLenSz is the exact length of byte slice capable to store the
	// length of packed DNS message.  It's essentially the size of a uint16,
	// and the cache keys also use it as the size of a uint16 field.
	packedMsgLenSz = 2
	// expTimeSz is the exact length of byte slice capable to store the
	// expiration time of the response.  It's essentially the size of a
	// uint32.
	expTimeSz = 4

	// minPackedLen is the minimum length of the packed cacheItem: the
	// expiration time followed by the length of the packed message.
	minPackedLen = expTimeSz + packedMsgLenSz
)

// pack serializes ci into a byte slice laid out as follows:
//
//	expiration time, Unix seconds  (uint32, big-endian)
//	length of the packed message   (uint16, big-endian)
//	packed message
//	upstream address
//
// The expiration time is the current time plus ci.ttl.  An error from packing
// ci.m is ignored.
func (ci *cacheItem) pack() (packed []byte) {
	msg, _ := ci.m.Pack()
	expiresAt := uint32(time.Now().Unix()) + ci.ttl

	var hdr [minPackedLen]byte
	binary.BigEndian.PutUint32(hdr[:expTimeSz], expiresAt)
	binary.BigEndian.PutUint16(hdr[expTimeSz:], uint16(len(msg)))

	packed = make([]byte, 0, minPackedLen+len(msg)+len(ci.u))
	packed = append(packed, hdr[:]...)
	packed = append(packed, msg...)
	packed = append(packed, ci.u...)

	return packed
}

// optimisticTTL is the default TTL for expired cached responses in seconds.
// An optimistic cache sets it on the records of the expired items it returns.
const optimisticTTL = 10

// unpackItem converts the data into cacheItem using req as a request message.
// expired is true if the item exists but expired.  The expired cached items are
// only returned if c is optimistic.  req must not be nil.
//
// data is expected to have the layout produced by cacheItem.pack.  ci is nil
// if data is shorter than minPackedLen, if the item is expired and c isn't
// optimistic, if the stored message length is zero, or if the stored message
// can't be unpacked.  The records of the returned message get the remaining
// TTL of the item, or optimisticTTL for an expired one.  The ttl field of the
// returned item is left zero.
func (c *cache) unpackItem(data []byte, req *dns.Msg) (ci *cacheItem, expired bool) {
	if len(data) < minPackedLen {
		return nil, false
	}

	b := bytes.NewBuffer(data)
	// The first expTimeSz bytes are the expiration time in Unix seconds.
	expire := int64(binary.BigEndian.Uint32(b.Next(expTimeSz)))
	now := time.Now().Unix()
	var ttl uint32
	if expired = expire <= now; expired {
		if !c.optimistic {
			return nil, expired
		}

		// Still serve the expired item, but with a short fixed TTL.
		ttl = optimisticTTL
	} else {
		ttl = uint32(expire - now)
	}

	// The next packedMsgLenSz bytes are the length of the packed message.
	l := int(binary.BigEndian.Uint16(b.Next(packedMsgLenSz)))
	if l == 0 {
		return nil, expired
	}

	m := &dns.Msg{}
	if m.Unpack(b.Next(l)) != nil {
		return nil, expired
	}

	// Build a new reply to req with the cached response code instead of
	// returning the cached message as is, and carry over its AD and RA flags.
	res := (&dns.Msg{}).SetRcode(req, m.Rcode)
	res.AuthenticatedData = m.AuthenticatedData
	res.RecursionAvailable = m.RecursionAvailable

	var doBit bool
	if o := req.IsEdns0(); o != nil {
		doBit = o.Do()
	}

	// Don't return OPT records from cache since it's deprecated by RFC 6891.
	// If the request has DO bit set we only remove all the OPT RRs, and also
	// all DNSSEC RRs otherwise.
	filterMsg(res, m, req.AuthenticatedData, doBit, ttl)

	return &cacheItem{
		m: res,
		// Everything after the packed message is the upstream address.
		u: string(b.Next(b.Len())),
	}, expired
}

// initCache sets up p.cache and p.shortFlighter when p.CacheEnabled is true,
// logging the configured size.  When the cache is disabled, it only logs that.
func (p *Proxy) initCache() {
	if p.CacheEnabled {
		size := p.CacheSizeBytes
		log.Info("dnsproxy: cache: enabled, size %d b", size)

		p.cache = newCache(size, p.EnableEDNSClientSubnet, p.CacheOptimistic)
		p.shortFlighter = newOptimisticResolver(p)

		return
	}

	log.Info("dnsproxy: cache: disabled")
}

// newCache returns a properly initialized cache whose storages are created by
// createCache with the given size.  The subnet storage is only created when
// withECS is true.  optimistic defines whether expired items are returned.
func newCache(size int, withECS, optimistic bool) (c *cache) {
	var subnetItems glcache.Cache
	if withECS {
		subnetItems = createCache(size)
	}

	return &cache{
		itemsLock:           &sync.RWMutex{},
		itemsWithSubnetLock: &sync.RWMutex{},
		items:               createCache(size),
		itemsWithSubnet:     subnetItems,
		optimistic:          optimistic,
	}
}

// get looks req up in the simple cache.  ci is the found item, if any, and
// expired tells whether its TTL has passed.  key is the key computed for req,
// returned so that callers don't have to compute it again.  key is nil if req
// can't be looked up at all.  Entries that can't be returned, e.g. malformed
// or expired in a non-optimistic cache, are removed.
func (c *cache) get(req *dns.Msg) (ci *cacheItem, expired bool, key []byte) {
	c.itemsLock.RLock()
	defer c.itemsLock.RUnlock()

	if canLookUpInCache(c.items, req) {
		key = msgToKey(req)
		if data := c.items.Get(key); data != nil {
			ci, expired = c.unpackItem(data, req)
			if ci == nil {
				c.items.Del(key)
			}
		}
	}

	return ci, expired, key
}

// getWithSubnet returns cached item for the req if it's found by n using the
// longest matching prefix of n.  expired is true if the item's TTL is expired.
// k is the last key that was looked up: the key of the found item or, if
// nothing is found, the key without an address.  It's returned to avoid
// recalculating it afterwards.  k is nil if the lookup isn't possible.
//
// Note that a slow longest-prefix-match algorithm is used, so cache searches
// are performed up to mask+1 times.
func (c *cache) getWithSubnet(req *dns.Msg, n *net.IPNet) (ci *cacheItem, expired bool, k []byte) {
	c.itemsWithSubnetLock.RLock()
	defer c.itemsWithSubnetLock.RUnlock()

	if !canLookUpInCache(c.itemsWithSubnet, req) {
		return nil, false, nil
	}

	ecsIP := n.IP.Mask(n.Mask)
	ipLen := len(ecsIP)
	m, _ := n.Mask.Size()

	k = msgToKeyWithSubnet(req, ecsIP, m)
	data := c.itemsWithSubnet.Get(k)

	// In order to reduce allocations, the key is narrowed in place.  All the
	// address bits starting from bit m, counting from the most significant
	// one, are already zero, so shortening the prefix from m+1 to m bits only
	// requires clearing bit m.  Since m is decremented before it's used, it's
	// always less than the original mask and so stays within the address.
	for data == nil && m > 0 {
		m--

		// Set mask identification byte in the key.
		k[keyMaskIndex] = byte(m)

		if m == 0 {
			// In case mask is zero, the key doesn't have IP in it.
			k = slices.Delete(k, keyIPIndex, keyIPIndex+ipLen)
		} else {
			k[keyIPIndex+m/8] &^= byte(0x80) >> (m % 8)
		}

		data = c.itemsWithSubnet.Get(k)
	}

	if data == nil {
		return nil, false, k
	}

	if ci, expired = c.unpackItem(data, req); ci == nil {
		c.itemsWithSubnet.Del(k)
	}

	return ci, expired, k
}

// canLookUpInCache reports whether a lookup of req in cache makes sense: the
// storage has to exist, and req has to be a non-nil message with exactly one
// question.
func canLookUpInCache(cache glcache.Cache, req *dns.Msg) (ok bool) {
	if cache == nil || req == nil {
		return false
	}

	return len(req.Question) == 1
}

// createCache returns a new LRU cache whose maximum size is cacheSize, or
// defaultCacheSize if cacheSize isn't positive.
func createCache(cacheSize int) (glc glcache.Cache) {
	maxSize := uint(defaultCacheSize)
	if cacheSize > 0 {
		maxSize = uint(cacheSize)
	}

	return glcache.New(glcache.Config{
		MaxSize:   maxSize,
		EnableLRU: true,
	})
}

// set stores the response m, resolved by u, in the simple cache under the key
// computed from m's question.  Responses that respToItem rejects are skipped.
func (c *cache) set(m *dns.Msg, u upstream.Upstream) {
	if item := respToItem(m, u); item != nil {
		key, packed := msgToKey(m), item.pack()

		c.itemsLock.Lock()
		defer c.itemsLock.Unlock()

		c.items.Set(key, packed)
	}
}

// setWithSubnet stores the response m, resolved by u, in the subnet cache
// under the key computed from m's question and the masked subnet.  Responses
// that respToItem rejects are skipped.
func (c *cache) setWithSubnet(m *dns.Msg, u upstream.Upstream, subnet *net.IPNet) {
	item := respToItem(m, u)
	if item == nil {
		return
	}

	ones, _ := subnet.Mask.Size()
	maskedIP := subnet.IP.Mask(subnet.Mask)
	key, packed := msgToKeyWithSubnet(m, maskedIP, ones), item.pack()

	c.itemsWithSubnetLock.Lock()
	defer c.itemsWithSubnetLock.Unlock()

	c.itemsWithSubnet.Set(key, packed)
}

// clearItems removes every entry from the simple cache while holding its lock.
func (c *cache) clearItems() {
	mu := c.itemsLock
	mu.Lock()
	defer mu.Unlock()

	c.items.Clear()
}

// clearItemsWithSubnet removes every entry from the subnet cache.  It does
// nothing when there is no subnet cache, i.e. when ECS is disabled.
func (c *cache) clearItemsWithSubnet() {
	if c.itemsWithSubnet != nil {
		c.itemsWithSubnetLock.Lock()
		defer c.itemsWithSubnetLock.Unlock()

		c.itemsWithSubnet.Clear()
	}
}

// cacheTTL returns the number of seconds for which m may be cached, or zero if
// it must not be cached.  Nil and truncated messages, messages without exactly
// one question, and messages with a zero calculated TTL are never cached.  For
// the rest the response code decides: NOERROR and NXDOMAIN responses are
// cached only when they carry useful data, following RFC 2308 for the negative
// ones, SERVFAIL responses are cached, and any other code is not.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-2.1,
// https://datatracker.ietf.org/doc/html/rfc2308#section-2.2.
func cacheTTL(m *dns.Msg) (ttl uint32) {
	if m == nil {
		return 0
	}

	if m.Truncated {
		log.Debug("dnsproxy: cache: truncated message; not caching")

		return 0
	}

	if len(m.Question) != 1 {
		log.Debug("dnsproxy: cache: message with wrong number of questions; not caching")

		return 0
	}

	if ttl = calculateTTL(m); ttl == 0 {
		log.Debug("dnsproxy: cache: ttl calculated to be 0; not caching")

		return 0
	}

	rcode := m.Rcode
	switch {
	case rcode == dns.RcodeServerFailure:
		return ttl
	case rcode == dns.RcodeSuccess && isCacheableSucceded(m):
		return ttl
	case rcode == dns.RcodeSuccess:
		log.Debug("dnsproxy: cache: not a cacheable noerror response; not caching")
	case rcode == dns.RcodeNameError && isCacheableNegative(m):
		return ttl
	case rcode == dns.RcodeNameError:
		log.Debug("dnsproxy: cache: not a cacheable nxdomain response; not caching")
	default:
		log.Debug("dnsproxy: cache: response code %s; not caching", dns.RcodeToString[rcode])
	}

	return 0
}

// hasIPAns reports whether the answer section of m has at least one A or AAAA
// record.
func hasIPAns(m *dns.Msg) (ok bool) {
	for i := range m.Answer {
		switch m.Answer[i].Header().Rrtype {
		case dns.TypeA, dns.TypeAAAA:
			return true
		}
	}

	return false
}

// isCacheableSucceded reports whether the successful response m is worth
// caching.  Responses to questions other than A and AAAA always are.  Responses
// to A and AAAA questions are only cached when they contain an address record
// or qualify as a cacheable negative response.
func isCacheableSucceded(m *dns.Msg) (ok bool) {
	switch m.Question[0].Qtype {
	case dns.TypeA, dns.TypeAAAA:
		return hasIPAns(m) || isCacheableNegative(m)
	default:
		return true
	}
}

// isCacheableNegative reports whether the authority section of m contains at
// least one SOA record and no NS records, so that the negative response can be
// considered authoritative.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-5 for the
// information on the responses from the authoritative server that should be
// cached by the forwarder.
func isCacheableNegative(m *dns.Msg) (ok bool) {
	hasSOA := false
	for _, rr := range m.Ns {
		rrType := rr.Header().Rrtype
		if rrType == dns.TypeNS {
			return false
		}

		hasSOA = hasSOA || rrType == dns.TypeSOA
	}

	return hasSOA
}

// ServFailMaxCacheTTL is the maximum time-to-live value for caching
// SERVFAIL responses in seconds.  It's consistent with the upper constraint
// of 5 minutes given by RFC 2308.  calculateTTL caps the TTL of SERVFAIL
// responses with it.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-7.1.
const ServFailMaxCacheTTL = 30

// calculateTTL returns the number of seconds for which m could be cached: the
// lowest TTL among the non-OPT records of all m's sections.  It returns zero if
// any such record has a zero TTL, or if there are no records that lower the
// initial sentinel value.  SERVFAIL responses are capped at
// ServFailMaxCacheTTL, including those without any such records.
func calculateTTL(m *dns.Msg) (ttl uint32) {
	// math.MaxUint32 is a sentinel value.  If it's still there after the
	// loop, no record with a lower TTL was found.
	ttl = math.MaxUint32
	sections := [...][]dns.RR{m.Answer, m.Ns, m.Extra}
	for _, section := range sections {
		for _, rr := range section {
			if ttl = minTTL(rr.Header(), ttl); ttl == 0 {
				return 0
			}
		}
	}

	if m.Rcode == dns.RcodeServerFailure && ttl > ServFailMaxCacheTTL {
		// This check comes before the sentinel one, so a SERVFAIL response
		// without TTL-bearing records gets ServFailMaxCacheTTL, not zero.
		return ServFailMaxCacheTTL
	}

	if ttl == math.MaxUint32 {
		return 0
	}

	return ttl
}

// minTTL returns the lower of h's TTL and ttl.  OPT records carry no real TTL,
// so for them ttl is returned unchanged.
func minTTL(h *dns.RR_Header, ttl uint32) uint32 {
	if h.Rrtype != dns.TypeOPT && h.Ttl < ttl {
		return h.Ttl
	}

	return ttl
}

// respectTTLOverrides clamps ttl to the range set by cacheMinTTL and
// cacheMaxTTL.  A zero cacheMaxTTL means there is no upper bound.  The lower
// bound is checked first, so it wins when it's greater than the upper one.
func respectTTLOverrides(ttl, cacheMinTTL, cacheMaxTTL uint32) uint32 {
	switch {
	case ttl < cacheMinTTL:
		return cacheMinTTL
	case cacheMaxTTL != 0 && ttl > cacheMaxTTL:
		return cacheMaxTTL
	default:
		return ttl
	}
}

// msgToKey builds the simple cache key from the first question of m: QTYPE and
// QCLASS as big-endian uint16 values followed by the lowercased QNAME.  The
// name part is sized by the length of the original name.  m must have at
// least one question.
func msgToKey(m *dns.Msg) (b []byte) {
	const nameOff = 2 * packedMsgLenSz

	q := m.Question[0]
	b = make([]byte, nameOff+len(q.Name))

	binary.BigEndian.PutUint16(b[:packedMsgLenSz], q.Qtype)
	binary.BigEndian.PutUint16(b[packedMsgLenSz:nameOff], q.Qclass)
	copy(b[nameOff:], strings.ToLower(q.Name))

	return b
}

const (
	// keyMaskIndex is the index of the byte with mask ones value.  It leaves
	// room for the DO byte and the QTYPE and QCLASS uint16 values before it.
	keyMaskIndex = 1 + 2*packedMsgLenSz

	// keyIPIndex is the start index of the IP address in the key.  The IP is
	// only present when the mask is not zero.
	keyIPIndex = keyMaskIndex + 1
)

// msgToKeyWithSubnet constructs the cache key from DO bit, type, class, subnet
// mask, client's IP address and question's name of m.  ecsIP is expected to be
// masked already.  The key is laid out as follows:
//
//	[0]                   DO bit of m's OPT record, 0 or 1
//	[1:3]                 QTYPE, big-endian
//	[3:5]                 QCLASS, big-endian
//	[keyMaskIndex]        mask
//	[keyIPIndex:...]      ecsIP, only if mask is not zero
//	[...]                 lowercased QNAME
func msgToKeyWithSubnet(m *dns.Msg, ecsIP net.IP, mask int) (key []byte) {
	q := m.Question[0]
	keyLen := keyIPIndex + len(q.Name)
	masked := mask != 0
	if masked {
		keyLen += len(ecsIP)
	}

	// Initialize the slice.
	key = make([]byte, keyLen)

	// Put DO.
	opt := m.IsEdns0()
	key[0] = mathutil.BoolToNumber[byte](opt != nil && opt.Do())

	// Put Qtype right after the DO byte.  Writing it from key[0] would
	// overwrite the DO bit and leave key[2] always zero.
	binary.BigEndian.PutUint16(key[1:], q.Qtype)

	// Put Qclass.
	binary.BigEndian.PutUint16(key[1+packedMsgLenSz:], q.Qclass)

	// Add mask.
	key[keyMaskIndex] = uint8(mask)
	k := keyIPIndex
	if masked {
		k += copy(key[keyIPIndex:], ecsIP)
	}

	copy(key[k:], strings.ToLower(q.Name))

	return key
}

// isDNSSEC reports whether r is a DNSSEC record, which here means one of NSEC,
// NSEC3, DS, RRSIG, SIG, or DNSKEY.
func isDNSSEC(r dns.RR) bool {
	t := r.Header().Rrtype

	return t == dns.TypeNSEC ||
		t == dns.TypeNSEC3 ||
		t == dns.TypeDS ||
		t == dns.TypeRRSIG ||
		t == dns.TypeSIG ||
		t == dns.TypeDNSKEY
}

// filterRRSlice returns copies of the RRs in rrs, leaving out OPT RRs and,
// unless do is true, DNSSEC RRs of any type other than except.  If ttl is not
// zero, it's set as the TTL of every returned copy.  rrs itself is never
// modified.  It returns nil for empty rrs.
func filterRRSlice(rrs []dns.RR, do bool, ttl uint32, except uint16) (filtered []dns.RR) {
	rrsLen := len(rrs)
	if rrsLen == 0 {
		return nil
	}

	filtered = make([]dns.RR, 0, rrsLen)
	for _, r := range rrs {
		rrType := r.Header().Rrtype
		if rrType == dns.TypeOPT || (!do && isDNSSEC(r) && rrType != except) {
			continue
		}

		// Copy the RR before changing its TTL so that the caller's records
		// stay intact.
		rc := dns.Copy(r)
		if ttl != 0 {
			rc.Header().Ttl = ttl
		}

		filtered = append(filtered, rc)
	}

	return filtered
}

// filterMsg fills the Answer, Ns, and Extra sections of dst with filtered
// copies of the corresponding sections of m, as done by filterRRSlice.  In
// the answer section, RRs of the question's type are kept even when they are
// DNSSEC RRs and do is false.  dst's AD bit is cleared unless ad or do is
// true.  m must have at least one question.
func filterMsg(dst, m *dns.Msg, ad, do bool, ttl uint32) {
	// As RFC 6840 says, validating resolvers should only set the AD bit when a
	// response both meets the conditions listed in RFC 4035, and the request
	// contained either a set DO bit or a set AD bit.
	if !ad && !do {
		dst.AuthenticatedData = false
	}

	// It's important to filter out only DNSSEC RRs that aren't explicitly
	// requested.
	//
	// See https://datatracker.ietf.org/doc/html/rfc4035#section-3.2.1 and
	// https://github.com/AdguardTeam/dnsproxy/issues/144.
	qtype := m.Question[0].Qtype
	dst.Answer = filterRRSlice(m.Answer, do, ttl, qtype)
	dst.Ns = filterRRSlice(m.Ns, do, ttl, dns.TypeNone)
	dst.Extra = filterRRSlice(m.Extra, do, ttl, dns.TypeNone)
}
07070100000044000081A4000000000000000000000001650C592100005730000000000000000000000000000000000000002400000000dnsproxy-0.55.0/proxy/cache_test.gopackage proxy

import (
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/miekg/dns"
)

// testCacheSize is the maximum size of cache for tests.
const testCacheSize = 4096

// testUpsAddr is the address reported by upstreamWithAddr.
const testUpsAddr = "https://upstream.address"

// upstreamWithAddr is an upstream that only reports testUpsAddr as its
// address.  Exchanging messages through it panics.
var upstreamWithAddr = &funcUpstream{
	exchangeFunc: func(m *dns.Msg) (resp *dns.Msg, err error) { panic("not implemented") },
	addressFunc:  func() (addr string) { return testUpsAddr },
}

// TestServeCached checks that a response put into the cache directly is served
// to a client over UDP.
func TestServeCached(t *testing.T) {
	// Prepare the proxy server.
	dnsProxy := createTestProxy(t, nil)
	dnsProxy.CacheEnabled = true // serve the response set below from the cache

	// Start listening.
	err := dnsProxy.Start()
	require.NoErrorf(t, err, "cannot start the DNS proxy: %s", err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	// Fill the cache.
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{
			Response: true,
		},
		Answer: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
	}).SetQuestion("google.com.", dns.TypeA)
	reply.SetEdns0(defaultUDPBufSize, false)

	dnsProxy.cache.set(reply, upstreamWithAddr)

	// Create a DNS-over-UDP client connection.
	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}

	// Create a DNS request.
	request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)
	request.SetEdns0(defaultUDPBufSize, false)

	r, _, err := client.Exchange(request, addr.String())
	require.NoErrorf(t, err, "error in the first request: %s", err)

	requireEqualMsgs(t, r, reply)
}

// TestCache_expired checks how items with a zero TTL, which are expired right
// away, are returned by a realistic and by an optimistic cache.
func TestCache_expired(t *testing.T) {
	const host = "google.com."

	// ans is shared by reply, so changing its TTL below changes reply.
	ans := &dns.A{
		Hdr: dns.RR_Header{
			Name:   host,
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
		},
		A: net.IP{8, 8, 8, 8},
	}
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{
			Response: true,
		},
		Answer: []dns.RR{ans},
	}).SetQuestion(host, dns.TypeA)

	testCases := []struct {
		name       string
		ttl        uint32
		wantTTL    uint32
		optimistic bool
	}{{
		name:       "realistic_hit",
		ttl:        defaultTestTTL,
		wantTTL:    defaultTestTTL,
		optimistic: false,
	}, {
		name:       "realistic_miss",
		ttl:        0,
		wantTTL:    0,
		optimistic: false,
	}, {
		name:       "optimistic_hit",
		ttl:        defaultTestTTL,
		wantTTL:    defaultTestTTL,
		optimistic: true,
	}, {
		name:       "optimistic_expired",
		ttl:        0,
		wantTTL:    optimisticTTL,
		optimistic: true,
	}}

	testCache := newCache(testCacheSize, false, false)
	for _, tc := range testCases {
		ans.Hdr.Ttl = tc.ttl
		req := (&dns.Msg{}).SetQuestion(host, dns.TypeA)

		t.Run(tc.name, func(t *testing.T) {
			if tc.optimistic {
				testCache.optimistic = true
				t.Cleanup(func() { testCache.optimistic = false })
			}

			// Pack the item directly instead of using set, since set
			// doesn't store items with a zero TTL.
			key := msgToKey(reply)
			data := (&cacheItem{
				m:   reply,
				u:   testUpsAddr,
				ttl: tc.ttl,
			}).pack()
			testCache.items.Set(key, data)
			t.Cleanup(testCache.items.Clear)

			r, expired, key := testCache.get(req)
			assert.Equal(t, msgToKey(req), key)
			assert.Equal(t, tc.ttl == 0, expired)

			if tc.wantTTL != 0 {
				require.NotNil(t, r)

				assert.Equal(t, tc.wantTTL, r.m.Answer[0].Header().Ttl)
				assert.Equal(t, testUpsAddr, r.u)
			} else {
				require.Nil(t, r)
			}
		})
	}
}

// TestCacheDO checks that a response cached with the DO bit set is returned
// both for requests without and with the DO bit.
func TestCacheDO(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	// Fill the cache.
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{
			Response: true,
		},
		Answer: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
	}).SetQuestion("google.com.", dns.TypeA)
	reply.SetEdns0(4096, true)

	// Store in cache.
	testCache.set(reply, upstreamWithAddr)

	// Make a request.
	request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)

	t.Run("without_do", func(t *testing.T) {
		ci, expired, key := testCache.get(request)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(request), key)
		assert.NotNil(t, ci)
	})

	t.Run("with_do", func(t *testing.T) {
		// Restore the request without the OPT record after the subtest.
		reqClone := request.Copy()
		t.Cleanup(func() {
			request = reqClone
		})

		request.SetEdns0(4096, true)

		ci, expired, key := testCache.get(request)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(request), key)

		require.NotNil(t, ci)

		assert.Equal(t, testUpsAddr, ci.u)
	})
}

// TestCacheCNAME checks that a response to an A question that only has a CNAME
// record isn't cached, while one that also has an A record is.
func TestCacheCNAME(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	// Fill the cache
	reply := (&dns.Msg{
		MsgHdr: dns.MsgHdr{
			Response: true,
		},
		Answer: []dns.RR{newRR(t, "google.com.", dns.TypeCNAME, 3600, "test.google.com.")},
	}).SetQuestion("google.com.", dns.TypeA)
	testCache.set(reply, upstreamWithAddr)

	// Create a DNS request.
	request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)

	t.Run("no_cnames", func(t *testing.T) {
		r, expired, _ := testCache.get(request)
		assert.Nil(t, r)
		assert.False(t, expired)
	})

	// Now fill the cache with a cacheable CNAME response.
	reply.Answer = append(reply.Answer, newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8}))
	testCache.set(reply, upstreamWithAddr)

	// We are testing that a proper CNAME response gets cached
	t.Run("cnames_exist", func(t *testing.T) {
		r, expired, key := testCache.get(request)
		assert.False(t, expired)
		assert.Equal(t, key, msgToKey(request))

		require.NotNil(t, r)

		assert.Equal(t, testUpsAddr, r.u)
	})
}

// TestCache_uncacheable checks that a response that cacheTTL rejects doesn't
// get into the cache.
func TestCache_uncacheable(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	// Create a DNS request.
	request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)
	// Fill the cache.
	reply := (&dns.Msg{}).SetRcode(request, dns.RcodeBadAlg)

	// We are testing that BADALG responses aren't cached.  Note that SERVFAIL
	// responses are cacheable, and that reply has no records, so its
	// calculated TTL is zero as well.
	testCache.set(reply, upstreamWithAddr)

	r, expired, _ := testCache.get(request)
	assert.Nil(t, r)
	assert.False(t, expired)
}

// TestCache_concurrent sets and gets different hosts from the same cache in
// separate goroutines.
func TestCache_concurrent(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	hosts := map[string]string{
		dns.Fqdn("yandex.com"):     "213.180.204.62",
		dns.Fqdn("google.com"):     "8.8.8.8",
		dns.Fqdn("www.google.com"): "8.8.4.4",
		dns.Fqdn("youtube.com"):    "173.194.221.198",
		dns.Fqdn("car.ru"):         "37.220.161.35",
		dns.Fqdn("cat.ru"):         "192.56.231.67",
	}

	g := &sync.WaitGroup{}
	g.Add(len(hosts))

	// Each goroutine calls g.Done when it finishes.
	for k, v := range hosts {
		go setAndGetCache(t, testCache, g, k, v)
	}

	g.Wait()
}

// TestCacheExpiration checks that responses with a TTL of one second are
// returned from the cache right away and disappear from it shortly after.
func TestCacheExpiration(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)
	dnsProxy.CacheEnabled = true

	err := dnsProxy.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	// Create dns messages with TTL of 1 second.
	rrs := []dns.RR{
		newRR(t, "youtube.com.", dns.TypeA, 1, net.IP{173, 194, 221, 198}),
		newRR(t, "google.com.", dns.TypeA, 1, net.IP{8, 8, 8, 8}),
		newRR(t, "yandex.com.", dns.TypeA, 1, net.IP{213, 180, 204, 62}),
	}
	replies := make([]*dns.Msg, len(rrs))
	for i, rr := range rrs {
		rep := (&dns.Msg{
			MsgHdr: dns.MsgHdr{
				Response: true,
			},
			Answer: []dns.RR{dns.Copy(rr)},
		}).SetQuestion(rr.Header().Name, dns.TypeA)
		dnsProxy.cache.set(rep, upstreamWithAddr)
		replies[i] = rep
	}

	for _, r := range replies {
		ci, expired, key := dnsProxy.cache.get(r)
		require.NotNil(t, ci)

		assert.False(t, expired)
		assert.Equal(t, msgToKey(ci.m), key)

		requireEqualMsgs(t, ci.m, r)
	}

	// The expiration time is stored in whole seconds, so the items should
	// expire within about a second after being set.
	assert.Eventually(t, func() bool {
		for _, r := range replies {
			if ci, _, _ := dnsProxy.cache.get(r); ci != nil {
				return false
			}
		}

		return true
	}, 1100*time.Millisecond, 100*time.Millisecond)
}

// TestCacheExpirationWithTTLOverride checks that responses resolved through the
// proxy are cached with their TTL raised to CacheMinTTL or lowered to
// CacheMaxTTL as needed.
func TestCacheExpirationWithTTLOverride(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)
	dnsProxy.CacheEnabled = true
	dnsProxy.CacheMinTTL = 20
	dnsProxy.CacheMaxTTL = 40

	u := testUpstream{}
	dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{&u}

	err := dnsProxy.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	// d is shared by the subtests, each of which sets its own request.
	d := &DNSContext{}

	t.Run("replace_min", func(t *testing.T) {
		d.Req = createHostTestMessage("host")
		d.Addr = &net.TCPAddr{}

		// The TTL of 10 is below CacheMinTTL.
		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host.",
				Ttl:    10,
			},
			A: net.IP{4, 3, 2, 1},
		}}

		err = dnsProxy.Resolve(d)
		require.NoError(t, err)

		ci, expired, key := dnsProxy.cache.get(d.Req)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(d.Req), key)

		require.NotNil(t, ci)
		assert.Equal(t, dnsProxy.CacheMinTTL, ci.m.Answer[0].Header().Ttl)
	})

	t.Run("replace_max", func(t *testing.T) {
		d.Req = createHostTestMessage("host2")
		d.Addr = &net.TCPAddr{}

		// The TTL of 60 is above CacheMaxTTL.
		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host2.",
				Ttl:    60,
			},
			A: net.IP{4, 3, 2, 1},
		}}

		err = dnsProxy.Resolve(d)
		assert.Nil(t, err)

		ci, expired, key := dnsProxy.cache.get(d.Req)
		assert.False(t, expired)
		assert.Equal(t, msgToKey(d.Req), key)

		require.NotNil(t, ci)
		assert.Equal(t, dnsProxy.CacheMaxTTL, ci.m.Answer[0].Header().Ttl)
	})
}

// testEntry is a response to put into the cache before running the cases.
type testEntry struct {
	// q is the question name.
	q string
	// a is the answer section.
	a []dns.RR
	// t is the question type.
	t uint16
}

// testCase is a single cache lookup to check.
type testCase struct {
	// ok checks whether the lookup found an item.
	ok require.BoolAssertionFunc
	// q is the question name.
	q string
	// a is the expected answer section, if any.
	a []dns.RR
	// t is the question type.
	t uint16
}

// testCases is a table of cache entries and the lookups to check against them.
type testCases struct {
	cache []testEntry
	cases []testCase
}

// TestCache runs table-driven cache lookups: a simple hit and miss, lookups
// with differently-cased names, and lookups for a response with a zero TTL.
func TestCache(t *testing.T) {
	t.Run("simple", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "google.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.True,
				q:  "google.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})

	// The cache keys are case-insensitive with regard to the question name.
	t.Run("mixed_case", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "gOOgle.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.True,
				q:  "gOOgle.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.True,
				q:  "google.com.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				ok: require.True,
				q:  "GOOGLE.COM.",
				a:  []dns.RR{newRR(t, "google.com.", dns.TypeA, 3600, net.IP{8, 8, 8, 8})},
				t:  dns.TypeA,
			}, {
				q:  "gOOgle.com.",
				t:  dns.TypeMX,
				ok: require.False,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "GOOGLE.COM.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})

	// Responses with a zero TTL must not be cached.
	t.Run("zero_ttl", func(t *testing.T) {
		testCases{
			cache: []testEntry{{
				q: "gOOgle.com.",
				a: []dns.RR{newRR(t, "google.com.", dns.TypeA, 0, net.IP{8, 8, 8, 8})},
				t: dns.TypeA,
			}},
			cases: []testCase{{
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeA,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}, {
				ok: require.False,
				q:  "google.com.",
				t:  dns.TypeMX,
			}},
		}.run(t)
	})
}

// run fills a new cache with the entries from tests.cache and then checks every
// case from tests.cases against it.  For each case, it requires that the lookup
// doesn't report an expired entry and that the presence of an entry matches
// tc.ok.  When tc.a is set and an entry was found, it stores a response built
// from tc and compares the found cached message with it.
func (tests testCases) run(t *testing.T) {
	testCache := newCache(testCacheSize, false, false)

	for _, res := range tests.cache {
		reply := (&dns.Msg{
			MsgHdr: dns.MsgHdr{
				Response: true,
			},
			Answer: res.a,
		}).SetQuestion(res.q, res.t)
		testCache.set(reply, upstreamWithAddr)
	}

	for _, tc := range tests.cases {
		request := (&dns.Msg{}).SetQuestion(tc.q, tc.t)

		ci, expired, _ := testCache.get(request)
		assert.False(t, expired)
		tc.ok(t, ci != nil)

		// Cases without an expected answer only check the presence of the
		// entry.  Continue instead of returning, since the following cases
		// must still be checked.
		if tc.a == nil || ci == nil {
			continue
		}

		reply := (&dns.Msg{
			MsgHdr: dns.MsgHdr{
				Response: true,
			},
			Answer: tc.a,
		}).SetQuestion(tc.q, tc.t)

		testCache.set(reply, upstreamWithAddr)

		requireEqualMsgs(t, ci.m, reply)
	}
}

// requireEqualMsgs requires the messages to be equal except for their IDs, the
// Rdlength of their answers, the case of their questions' names, and the byte
// length of the addresses in their A answers.  Note that it normalizes both
// messages in place: the answers' Rdlength and A addresses, as well as the
// question names, of expected and actual may be modified.
func requireEqualMsgs(t *testing.T, expected, actual *dns.Msg) {
	t.Helper()

	// temp is a shallow copy, so it still shares the question and answer
	// slices, as well as the records themselves, with expected.
	temp := *expected
	temp.Id = actual.Id

	require.Equal(t, len(temp.Answer), len(actual.Answer))
	for i, ans := range actual.Answer {
		temp.Answer[i].Header().Rdlength = ans.Header().Rdlength
	}

	// normalizeA converts the IPv4 addresses of A records into their 4-byte
	// form, since net.IP values of different lengths are not deeply equal.
	normalizeA := func(rrs []dns.RR) {
		for _, rr := range rrs {
			a, ok := rr.(*dns.A)
			if !ok {
				continue
			}

			if a4 := a.A.To4(); a4 != nil {
				a.A = a4
			}
		}
	}
	normalizeA(temp.Answer)
	normalizeA(actual.Answer)

	// lowerNames lowercases the question names, since the case of questions
	// is ignored by the comparison.
	lowerNames := func(qs []dns.Question) {
		for i := range qs {
			qs[i].Name = strings.ToLower(qs[i].Name)
		}
	}
	lowerNames(temp.Question)
	lowerNames(actual.Question)

	require.Equal(t, &temp, actual)
}

// setAndGetCache stores in c an A response for host with ip as its address and
// a TTL of 1.  It then requires the entry to be retrievable twice, unexpired and
// under the expected key, and finally waits for the entry to disappear from c.
// g.Done is called when the function returns.
func setAndGetCache(t *testing.T, c *cache, g *sync.WaitGroup, host, ip string) {
	defer g.Done()

	addr := net.ParseIP(ip)
	msg := &dns.Msg{
		MsgHdr: dns.MsgHdr{Response: true},
		Answer: []dns.RR{newRR(t, host, dns.TypeA, 1, addr)},
	}
	msg.SetQuestion(host, dns.TypeA)

	c.set(msg, upstreamWithAddr)

	const lookups = 2
	for left := lookups; left > 0; left-- {
		cached, isExpired, cacheKey := c.get(msg)
		require.NotNilf(t, cached, "no cache found for %s", host)

		assert.False(t, isExpired)
		assert.Equal(t, msgToKey(msg), cacheKey)

		requireEqualMsgs(t, cached.m, msg)
	}

	isRemoved := func() (ok bool) {
		cached, _, _ := c.get(msg)

		return cached == nil
	}

	const (
		waitFor = 1100 * time.Millisecond
		tick    = 100 * time.Millisecond
	)
	assert.Eventuallyf(t, isRemoved, waitFor, tick, "cache for %s should already be removed", host)
}

// TestCache_getWithSubnet checks the lookups of responses cached with and
// without a subnet.  The subtests share the cache and depend on the entries
// added between them, so they rely on running in order.
func TestCache_getWithSubnet(t *testing.T) {
	const testFQDN = "example.com."

	ip1234, ip2234, ip3234 := net.IP{1, 2, 3, 4}, net.IP{2, 2, 3, 4}, net.IP{3, 2, 3, 4}
	req := (&dns.Msg{}).SetQuestion(testFQDN, dns.TypeA)
	mask16 := net.CIDRMask(16, netutil.IPv4BitLen)
	mask24 := net.CIDRMask(24, netutil.IPv4BitLen)

	c := newCache(testCacheSize, true, false)

	t.Run("empty", func(t *testing.T) {
		ci, expired, _ := c.getWithSubnet(req, &net.IPNet{IP: ip1234, Mask: mask24})
		assert.Nil(t, ci)
		assert.False(t, expired)
	})

	// Add a response with subnet.
	resp := (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 1, net.IP{1, 1, 1, 1})},
	}).SetReply(req)
	c.setWithSubnet(resp, upstreamWithAddr, &net.IPNet{IP: ip1234, Mask: mask16})

	t.Run("different_ip", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip2234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip2234, 0), key)
		assert.Nil(t, ci)
	})

	// Add a response entry with subnet #2.
	resp = (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 1, net.IP{2, 2, 2, 2})},
	}).SetReply(req)
	c.setWithSubnet(resp, upstreamWithAddr, &net.IPNet{IP: ip2234, Mask: mask16})

	// Add a response entry without subnet.
	resp = (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 1, net.IP{3, 3, 3, 3})},
	}).SetReply(req)
	c.setWithSubnet(resp, upstreamWithAddr, &net.IPNet{IP: nil, Mask: nil})

	t.Run("with_subnet_1", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip1234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip1234.Mask(mask16), 16), key)

		require.NotNil(t, ci)
		require.NotNil(t, ci.m)
		require.NotEmpty(t, ci.m.Answer)

		a := testutil.RequireTypeAssert[*dns.A](t, ci.m.Answer[0])
		assert.True(t, a.A.Equal(net.IP{1, 1, 1, 1}))
	})

	t.Run("with_subnet_2", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip2234, Mask: mask24})
		assert.False(t, expired)
		assert.Equal(t, msgToKeyWithSubnet(req, ip2234.Mask(mask16), 16), key)

		require.NotNil(t, ci)
		require.NotNil(t, ci.m)
		require.NotEmpty(t, ci.m.Answer)

		a := testutil.RequireTypeAssert[*dns.A](t, ci.m.Answer[0])
		assert.True(t, a.A.Equal(net.IP{2, 2, 2, 2}))
	})

	t.Run("with_subnet_3", func(t *testing.T) {
		ci, expired, key := c.getWithSubnet(req, &net.IPNet{IP: ip3234, Mask: mask24})
		assert.False(t, expired)
		// No subnet entry contains ip3234, so the lookup falls back to the entry
		// cached without a subnet.  Like the other mask-0 checks, build the key
		// from the queried address.
		assert.Equal(t, msgToKeyWithSubnet(req, ip3234, 0), key)

		require.NotNil(t, ci)
		require.NotNil(t, ci.m)
		require.NotEmpty(t, ci.m.Answer)

		a := testutil.RequireTypeAssert[*dns.A](t, ci.m.Answer[0])
		assert.True(t, a.A.Equal(net.IP{3, 3, 3, 3}))
	})
}

// TestCache_getWithSubnet_mask checks that a response cached for a network is
// returned for a client address from within that network and isn't returned
// for an address outside of it.
func TestCache_getWithSubnet_mask(t *testing.T) {
	const (
		testFQDN = "example.com."

		// cidrMaskOnes is the prefix length of the cached network.
		cidrMaskOnes = 20

		// clientMaskOnes is the prefix length of the subnets in the requests.
		clientMaskOnes = 24
	)

	var (
		// The cachedIP/cidrMaskOnes network contains testIP but not noMatchIP.
		cachedIP  = net.IP{176, 112, 176, 0}
		testIP    = net.IP{176, 112, 191, 0}
		noMatchIP = net.IP{177, 112, 191, 0}

		ansIP = net.IP{4, 4, 4, 4}
	)

	cidrMask := net.CIDRMask(cidrMaskOnes, netutil.IPv4BitLen)

	c := newCache(testCacheSize, true, true)

	req := (&dns.Msg{}).SetQuestion(testFQDN, dns.TypeA)
	resp := (&dns.Msg{
		Answer: []dns.RR{newRR(t, testFQDN, dns.TypeA, 300, ansIP)},
	}).SetReply(req)

	// Cache the response for the network that contains testIP.
	cachedNet := &net.IPNet{IP: cachedIP, Mask: cidrMask}
	c.setWithSubnet(resp, upstreamWithAddr, cachedNet)

	t.Run("mask_matched", func(t *testing.T) {
		clientNet := &net.IPNet{
			IP:   testIP,
			Mask: net.CIDRMask(clientMaskOnes, netutil.IPv4BitLen),
		}

		got, isExpired, gotKey := c.getWithSubnet(req, clientNet)
		assert.False(t, isExpired)

		wantKey := msgToKeyWithSubnet(req, testIP.Mask(cidrMask), cidrMaskOnes)
		assert.Equal(t, wantKey, gotKey)

		require.NotNil(t, got)
		require.NotNil(t, got.m)
		require.NotEmpty(t, got.m.Answer)

		rec := testutil.RequireTypeAssert[*dns.A](t, got.m.Answer[0])
		assert.True(t, rec.A.Equal(ansIP))
	})

	t.Run("no_mask_matched", func(t *testing.T) {
		clientNet := &net.IPNet{
			IP:   noMatchIP,
			Mask: net.CIDRMask(clientMaskOnes, netutil.IPv4BitLen),
		}

		got, isExpired, gotKey := c.getWithSubnet(req, clientNet)
		assert.False(t, isExpired)

		wantKey := msgToKeyWithSubnet(req, noMatchIP, 0)
		assert.Equal(t, wantKey, gotKey)
		assert.Nil(t, got)
	})
}

// TestCache_IsCacheable_negative checks the TTL that cacheTTL chooses for the
// negative responses modelled after RFC 2308 and for a SERVFAIL response.
func TestCache_IsCacheable_negative(t *testing.T) {
	const someTTL = 3600

	msgHdr := func(rcode int) (hdr dns.MsgHdr) { return dns.MsgHdr{Id: dns.Id(), Rcode: rcode} }
	aQuestions := func(name string) []dns.Question {
		return []dns.Question{{
			Name:   name,
			Qtype:  dns.TypeA,
			Qclass: dns.ClassINET,
		}}
	}

	// rrHdr returns the header of an IN-class record with the given owner name
	// and type and with someTTL as its TTL.
	rrHdr := func(name string, rrType uint16) (hdr dns.RR_Header) {
		return dns.RR_Header{
			Name:   name,
			Rrtype: rrType,
			Class:  dns.ClassINET,
			Ttl:    someTTL,
		}
	}

	cnameAns := func(name, cname string) (rr dns.RR) {
		return &dns.CNAME{
			Hdr:    rrHdr(name, dns.TypeCNAME),
			Target: cname,
		}
	}

	soaAns := func(name, ns, mbox string) (rr dns.RR) {
		return &dns.SOA{
			Hdr:  rrHdr(name, dns.TypeSOA),
			Ns:   ns,
			Mbox: mbox,
		}
	}

	nsAns := func(name, ns string) (rr dns.RR) {
		return &dns.NS{
			Hdr: rrHdr(name, dns.TypeNS),
			Ns:  ns,
		}
	}

	aAns := func(name string, a net.IP) (rr dns.RR) {
		return &dns.A{
			Hdr: rrHdr(name, dns.TypeA),
			A:   a,
		}
	}

	const (
		hostname        = "AN.EXAMPLE."
		anotherHostname = "ANOTHER.EXAMPLE."
		cname           = "TRIPPLE.XX."
		mbox            = "HOSTMASTER.NS1.XX."
		ns1, ns2        = "NS1.XX.", "NS2.XX."
		xx              = "XX."
	)

	// See https://datatracker.ietf.org/doc/html/rfc2308.
	testCases := []struct {
		// resp is the response to compute the TTL for.
		resp    *dns.Msg
		name    string
		wantTTL uint32
	}{{
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
			Ns: []dns.RR{
				soaAns(xx, ns1, mbox),
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nxdomain_response_type_1",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
			Ns:       []dns.RR{soaAns(xx, ns1, mbox)},
		},
		name:    "rfc2308_nxdomain_response_type_2",
		wantTTL: someTTL,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
		},
		name:    "rfc2308_nxdomain_response_type_3",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeNameError),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
			Ns: []dns.RR{
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nxdomain_response_type_4",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(hostname),
			Answer:   []dns.RR{cnameAns(hostname, cname)},
			Ns: []dns.RR{
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nxdomain_referral_response",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(anotherHostname),
			Ns: []dns.RR{
				soaAns(xx, ns1, mbox),
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nodata_response_type_1",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(anotherHostname),
			Ns:       []dns.RR{soaAns(xx, ns1, mbox)},
		},
		name:    "rfc2308_nodata_response_type_2",
		wantTTL: someTTL,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(anotherHostname),
		},
		name:    "rfc2308_nodata_response_type_3",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeSuccess),
			Question: aQuestions(anotherHostname),
			Ns: []dns.RR{
				nsAns(xx, ns1),
				nsAns(xx, ns2),
			},
			Extra: []dns.RR{
				aAns(ns1, net.IP{127, 0, 0, 2}),
				aAns(ns2, net.IP{127, 0, 0, 3}),
			},
		},
		name:    "rfc2308_nodata_referral_response",
		wantTTL: 0,
	}, {
		resp: &dns.Msg{
			MsgHdr:   msgHdr(dns.RcodeServerFailure),
			Question: aQuestions(anotherHostname),
		},
		name:    "servfail_response",
		wantTTL: ServFailMaxCacheTTL,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.wantTTL, cacheTTL(tc.resp))
		})
	}
}
07070100000045000081A4000000000000000000000001650C5921000025C2000000000000000000000000000000000000002000000000dnsproxy-0.55.0/proxy/config.gopackage proxy

import (
	"crypto/tls"
	"fmt"
	"net"
	"net/netip"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/ameshkov/dnscrypt/v2"
)

// UpstreamModeType is the type of the logic through which the upstream servers
// are used.
type UpstreamModeType int

// Upstream modes.  The zero value of UpstreamModeType is UModeLoadBalance.
const (
	// UModeLoadBalance is the load-balancing upstream mode.
	UModeLoadBalance UpstreamModeType = 0

	// UModeParallel enables parallel queries to all configured upstream
	// servers.
	UModeParallel UpstreamModeType = 1

	// UModeFastestAddr enables the Fastest Address algorithm.
	UModeFastestAddr UpstreamModeType = 2
)

// BeforeRequestHandler is an optional custom handler called before DNS
// requests.  If it returns false, the request won't be processed at all.
type BeforeRequestHandler func(p *Proxy, d *DNSContext) (bool, error)

// RequestHandler is an optional custom handler for DNS requests.  It is called
// instead of the default method (Proxy.Resolve()).  See handler_test.go for
// examples.
type RequestHandler func(p *Proxy, d *DNSContext) error

// ResponseHandler is a callback that is called when a DNS query has been
// processed.  d is the current DNS query context, which contains the response
// if it was successful, and err is the error, if any.
type ResponseHandler func(d *DNSContext, err error)

// Config contains all the fields necessary for proxy configuration.
//
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
type Config struct {
	// Listeners
	// --

	UDPListenAddr         []*net.UDPAddr // if nil, then it does not listen for UDP
	TCPListenAddr         []*net.TCPAddr // if nil, then it does not listen for TCP
	HTTPSListenAddr       []*net.TCPAddr // if nil, then it does not listen for HTTPS (DoH)
	TLSListenAddr         []*net.TCPAddr // if nil, then it does not listen for TLS (DoT)
	QUICListenAddr        []*net.UDPAddr // if nil, then it does not listen for QUIC (DoQ)
	DNSCryptUDPListenAddr []*net.UDPAddr // if nil, then it does not listen for DNSCrypt
	DNSCryptTCPListenAddr []*net.TCPAddr // if nil, then it does not listen for DNSCrypt

	// Encryption configuration
	// --

	TLSConfig            *tls.Config    // necessary for TLS, HTTPS, QUIC
	HTTP3                bool           // if true, HTTPS server will also support HTTP/3
	DNSCryptProviderName string         // DNSCrypt provider name
	DNSCryptResolverCert *dnscrypt.Cert // DNSCrypt resolver certificate

	// Rate-limiting and anti-DNS amplification measures
	// --

	Ratelimit          int      // max number of requests per second from a given IP (0 to disable)
	RatelimitWhitelist []string // a list of whitelisted client IP addresses
	RefuseAny          bool     // if true, refuse ANY requests

	// TrustedProxies is the list of IP addresses and CIDR networks to
	// detect proxy servers addresses the DoH requests from which should be
	// handled.  The value of nil or an empty slice for this field makes
	// Proxy not trust any address.
	TrustedProxies []string

	// Upstream DNS servers and their settings
	// --

	// UpstreamConfig is a general set of DNS servers to forward requests to.
	UpstreamConfig *UpstreamConfig

	// PrivateRDNSUpstreamConfig is the set of upstream DNS servers for
	// resolving private IP addresses.  All the requests considered private will
	// be resolved via these upstream servers.  Such queries will finish with
	// [upstream.ErrNoUpstream] if it's empty.
	PrivateRDNSUpstreamConfig *UpstreamConfig

	// Fallbacks is a list of fallback resolvers.  Those will be used if the
	// general set fails to respond.
	Fallbacks *UpstreamConfig

	// UpstreamMode determines the logic through which upstreams will be used.
	UpstreamMode UpstreamModeType

	// FastestPingTimeout is the timeout for waiting the first successful
	// dialing when the UpstreamMode is set to UModeFastestAddr.  Non-positive
	// value will be replaced with the default one.
	FastestPingTimeout time.Duration

	// BogusNXDomain is the set of networks used to transform responses into
	// NXDOMAIN ones if they contain at least a single IP address within these
	// networks.  It's similar to dnsmasq's "bogus-nxdomain".
	BogusNXDomain []*net.IPNet

	// EnableEDNSClientSubnet enables the EDNS Client Subnet (ECS) option.  DNS
	// requests to the upstream server will contain an OPT record with the
	// Client Subnet option.  If the original request already has this option
	// set, we pass it through as is.  Otherwise, we set it ourselves using the
	// client IP with subnet /24 (for IPv4) and /56 (for IPv6).
	//
	// If the upstream server supports ECS, it sets the subnet number in the
	// response.  This subnet number, along with the client IP and other data,
	// is used as a cache key.  Next time, if a client from the same subnet
	// requests this host name, we get the response from the cache.  If another
	// client from a different subnet requests this host name, we pass its
	// request to the upstream server.
	//
	// If the upstream server doesn't support ECS (there's no subnet number in
	// the response), this response will be cached for all clients.
	//
	// If the client IP is private (i.e. not public), we don't add an EDNS
	// record to the request, and so there will be no EDNS record in the
	// response either.  We store these responses in the general cache (without
	// subnet), so they will never be used for clients with public IP
	// addresses.
	EnableEDNSClientSubnet bool

	// EDNSAddr is the ECS IP address used in requests.
	EDNSAddr net.IP

	// Cache settings
	// --

	CacheEnabled   bool   // cache status
	CacheSizeBytes int    // Cache size (in bytes). Default: 64k
	CacheMinTTL    uint32 // Minimum TTL for DNS entries (in seconds).
	CacheMaxTTL    uint32 // Maximum TTL for DNS entries (in seconds).
	// CacheOptimistic defines if the optimistic cache mechanism should be
	// used.
	CacheOptimistic bool

	// Handlers (for the case when dnsproxy is used as a library)
	// --

	BeforeRequestHandler BeforeRequestHandler // callback that is called before each request
	RequestHandler       RequestHandler       // callback that can handle incoming DNS requests
	ResponseHandler      ResponseHandler      // response callback

	// Other settings
	// --

	// HTTPSServerName sets the Server header of the HTTPS server responses, if
	// not empty.
	HTTPSServerName string

	// MaxGoroutines is the maximum number of goroutines processing DNS
	// requests.  Important for mobile users.
	//
	// TODO(a.garipov): Rename this to something like
	// “MaxDNSRequestGoroutines” in a later major version, as it doesn't
	// actually limit all goroutines.
	MaxGoroutines int

	// UDPBufferSize is the size of the read buffer on the underlying socket.
	// Larger read buffers can handle larger bursts of requests before packets
	// get dropped.
	UDPBufferSize int

	// UseDNS64 enables DNS64 handling.  If true, proxy will translate IPv4
	// answers into IPv6 answers using first of DNS64Prefs.  Note also that PTR
	// requests for addresses within the specified networks are considered
	// private and will be forwarded as PrivateRDNSUpstreamConfig specifies.
	UseDNS64 bool

	// DNS64Prefs is the set of NAT64 prefixes used for DNS64 handling.  nil
	// value disables the feature.  An empty value will be interpreted as the
	// default Well-Known Prefix.
	DNS64Prefs []netip.Prefix

	// PreferIPv6 tells the proxy to prefer IPv6 addresses when bootstrapping
	// upstreams that use hostnames.
	PreferIPv6 bool
}

// validateConfig verifies that the supplied configuration is valid and returns
// an error if it's not.  It also logs the enabled optional features: cache TTL
// overrides, ratelimiting, refusing ANY requests, and bogus-nxdomain networks.
func (p *Proxy) validateConfig() error {
	err := p.validateListenAddrs()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return err
	}

	err = p.UpstreamConfig.validate()
	if err != nil {
		return fmt.Errorf("validating general upstreams: %w", err)
	}

	// Allow both [Proxy.PrivateRDNSUpstreamConfig] and [Proxy.Fallbacks] to be
	// nil, but not empty.  nil means using the default values for those.

	err = p.PrivateRDNSUpstreamConfig.validate()
	if err != nil && !errors.Is(err, errNoDefaultUpstreams) {
		return fmt.Errorf("validating private RDNS upstreams: %w", err)
	}

	err = p.Fallbacks.validate()
	if err != nil && !errors.Is(err, errNoDefaultUpstreams) {
		return fmt.Errorf("validating fallbacks: %w", err)
	}

	if p.CacheMinTTL > 0 || p.CacheMaxTTL > 0 {
		log.Info("Cache TTL override is enabled. Min=%d, Max=%d", p.CacheMinTTL, p.CacheMaxTTL)
	}

	if p.Ratelimit > 0 {
		log.Info("Ratelimit is enabled and set to %d rps", p.Ratelimit)
	}

	if p.RefuseAny {
		log.Info("The server is configured to refuse ANY requests")
	}

	if len(p.BogusNXDomain) > 0 {
		log.Info("%d bogus-nxdomain IP specified", len(p.BogusNXDomain))
	}

	return nil
}

// validateListenAddrs returns an error if no listen addresses are set at all,
// if an encrypted listener (TLS, HTTPS, or QUIC) is requested without a TLS
// configuration, or if a DNSCrypt listener is requested without both the
// resolver certificate and the provider name.
func (p *Proxy) validateListenAddrs() error {
	if !p.hasListenAddrs() {
		return errors.Error("no listen address specified")
	}

	if p.TLSConfig == nil {
		switch {
		case p.TLSListenAddr != nil:
			return errors.Error("cannot create tls listener without tls config")
		case p.HTTPSListenAddr != nil:
			return errors.Error("cannot create https listener without tls config")
		case p.QUICListenAddr != nil:
			return errors.Error("cannot create quic listener without tls config")
		}
	}

	wantsDNSCrypt := p.DNSCryptTCPListenAddr != nil || p.DNSCryptUDPListenAddr != nil
	hasDNSCryptConf := p.DNSCryptResolverCert != nil && p.DNSCryptProviderName != ""
	if wantsDNSCrypt && !hasDNSCryptConf {
		return errors.Error("cannot create dnscrypt listener without dnscrypt config")
	}

	return nil
}

// hasListenAddrs returns true if at least one of the listen address slices of
// the configuration is not nil.
func (p *Proxy) hasListenAddrs() bool {
	switch {
	case p.UDPListenAddr != nil, p.TCPListenAddr != nil:
		return true
	case p.TLSListenAddr != nil, p.HTTPSListenAddr != nil, p.QUICListenAddr != nil:
		return true
	case p.DNSCryptUDPListenAddr != nil, p.DNSCryptTCPListenAddr != nil:
		return true
	default:
		return false
	}
}
07070100000046000081A4000000000000000000000001650C5921000026D6000000000000000000000000000000000000001F00000000dnsproxy-0.55.0/proxy/dns64.gopackage proxy

import (
	"fmt"
	"net"
	"net/netip"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/mathutil"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

const (
	// NAT64PrefixLength is the length of a NAT64 prefix in bytes, i.e. the
	// length of an IPv6 address without its trailing IPv4 part.
	NAT64PrefixLength = net.IPv6len - net.IPv4len

	// maxNAT64PrefixBitLen is the maximum length of a NAT64 prefix in bits,
	// i.e. NAT64PrefixLength expressed in bits.
	// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.2.
	maxNAT64PrefixBitLen = NAT64PrefixLength * 8

	// maxDNS64SynTTL is the maximum TTL, in seconds, for synthesized DNS64
	// responses that have no SOA records.
	//
	// If the SOA RR was not delivered with the negative response to the AAAA
	// query, then the DNS64 SHOULD use the TTL of the original A RR or 600
	// seconds, whichever is shorter.
	//
	// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.1.7.
	maxDNS64SynTTL uint32 = 600
)

// setupDNS64 initializes DNS64 settings, the NAT64 prefixes in particular.  If
// the DNS64 feature is enabled and no prefixes are configured, the default
// Well-Known Prefix is used, just like Section 5.2 of RFC 6147 prescribes.  Any
// configured set of prefixes discards the default Well-Known prefix unless it
// is specified explicitly.  Each prefix is also validated to be a valid IPv6
// CIDR with a maximum length of 96 bits.  The first specified prefix is then
// used to synthesize AAAA records.
//
// p.dns64Prefs is replaced as a whole, and only when all the configured
// prefixes are valid.
func (p *Proxy) setupDNS64() (err error) {
	if !p.Config.UseDNS64 {
		return nil
	}

	if len(p.Config.DNS64Prefs) == 0 {
		p.dns64Prefs = []netip.Prefix{dns64WellKnownPref}

		return nil
	}

	// Collect the prefixes into a new slice instead of appending to the
	// existing field, so that neither a repeated call nor an invalid prefix
	// leaves duplicated or partial data in p.dns64Prefs.
	prefs := make([]netip.Prefix, 0, len(p.Config.DNS64Prefs))
	for i, pref := range p.Config.DNS64Prefs {
		if !pref.Addr().Is6() {
			return fmt.Errorf("prefix at index %d: %q is not an IPv6 prefix", i, pref)
		}

		if pref.Bits() > maxNAT64PrefixBitLen {
			return fmt.Errorf("prefix at index %d: %q is too long for DNS64", i, pref)
		}

		prefs = append(prefs, pref.Masked())
	}

	p.dns64Prefs = prefs

	return nil
}

// checkDNS64 decides whether DNS64 should be performed for the AAAA request req
// answered with resp.  It returns a copy of req turned into an A request with a
// new ID, or nil if DNS64 is not desired.  For successful responses, it also
// removes the AAAA answers within the NAT64 prefixes from resp.  Both req and
// resp must not be nil.
//
// See https://datatracker.ietf.org/doc/html/rfc6147.
func (p *Proxy) checkDNS64(req, resp *dns.Msg) (dns64Req *dns.Msg) {
	if len(p.dns64Prefs) == 0 {
		return nil
	}

	// DNS64 operation for classes other than IN is undefined, and a DNS64 MUST
	// behave as though no DNS64 function is configured.
	if q := req.Question[0]; q.Qtype != dns.TypeAAAA || q.Qclass != dns.ClassINET {
		return nil
	}

	// A result with RCODE=3 (Name Error) is handled according to normal DNS
	// operation, which is normally to return the error to the client.
	if resp.Rcode == dns.RcodeNameError {
		return nil
	}

	// Any RCODE other than 0 and 3 is treated as though the RCODE were 0 and
	// the answer section were empty, so only successful responses are
	// filtered.  If the answer has at least one AAAA record outside the
	// excluded ranges, only such records are kept and no synthesis happens.
	if resp.Rcode == dns.RcodeSuccess {
		filtered, hasAnswers := p.filterNAT64Answers(resp.Answer)
		resp.Answer = filtered
		if hasAnswers {
			return nil
		}
	}

	aReq := req.Copy()
	aReq.Id = dns.Id()
	aReq.Question[0].Qtype = dns.TypeA

	return aReq
}

// filterNAT64Answers returns a new slice with the records from rrs except the
// AAAA records that are invalid or within one of the NAT64 prefixes.
// hasAnswers is true if the filtered slice contains at least a single AAAA
// record outside of the prefixes, a CNAME, or a DNAME.  Records of other types
// are kept but don't count as answers.
func (p *Proxy) filterNAT64Answers(
	rrs []dns.RR,
) (filtered []dns.RR, hasAnswers bool) {
	filtered = make([]dns.RR, 0, len(rrs))
	for _, rr := range rrs {
		switch typed := rr.(type) {
		case *dns.AAAA:
			addr, err := netutil.IPToAddrNoMapped(typed.AAAA)
			if err != nil {
				log.Error("proxy: bad aaaa record: %s", err)

				continue
			}

			if p.withinDNS64(addr) {
				// Drop the excluded record.
				continue
			}

			hasAnswers = true
		case *dns.CNAME, *dns.DNAME:
			// If the response contains a CNAME or a DNAME, then the chain is
			// followed until the first terminating A or AAAA record is
			// reached.  These chains aren't followed here, so treat such
			// records as passable answers.
			hasAnswers = true
		}

		filtered = append(filtered, rr)
	}

	return filtered, hasAnswers
}

// synthDNS64 synthesizes a DNS64 response using the original response as a
// basis and modifying it with data from resp, the response to the DNS64 A
// request.  It returns true if origResp was actually modified.  It returns
// false without modifying origResp if resp has no answers or if any of its A
// answers is invalid.
func (p *Proxy) synthDNS64(origReq, origResp, resp *dns.Msg) (ok bool) {
	if len(resp.Answer) == 0 {
		// If there is an empty answer, then the DNS64 responds to the original
		// querying client with the answer the DNS64 received to the original
		// (initiator's) query.
		return false
	}

	// The Time to Live (TTL) field is set to the minimum of the TTL of the
	// original A RR and the SOA RR for the queried domain.  If the original
	// response contains no SOA records, the minimum of the TTL of the original
	// A RR and [maxDNS64SynTTL] should be used.  See [maxDNS64SynTTL].
	//
	// Domain names are compared case-insensitively, since the case of the SOA
	// owner name doesn't have to match the case of the question.
	qName := dns.CanonicalName(origReq.Question[0].Name)
	soaTTL := maxDNS64SynTTL
	for _, rr := range origResp.Ns {
		hdr := rr.Header()
		if hdr.Rrtype == dns.TypeSOA && dns.CanonicalName(hdr.Name) == qName {
			soaTTL = hdr.Ttl

			break
		}
	}

	newAns := make([]dns.RR, 0, len(resp.Answer))
	for _, ans := range resp.Answer {
		rr := p.synthRR(ans, soaTTL)
		if rr == nil {
			// The error should have already been logged.
			return false
		}

		newAns = append(newAns, rr)
	}

	origResp.Answer = newAns
	origResp.Ns = resp.Ns
	origResp.Extra = resp.Extra

	return true
}

// dns64WellKnownPref is the default prefix to use in an algorithmic mapping for
// DNS64.  See https://datatracker.ietf.org/doc/html/rfc6052#section-2.1.
var dns64WellKnownPref = netip.MustParsePrefix("64:ff9b::/96")

// withinDNS64 reports whether ip is contained in any of the configured DNS64
// prefixes.
//
// TODO(e.burkov):  Only the bytes of the first prefix from the set are used to
// construct the answer, so consider using some implementation of a prefix set
// for the rest.
func (p *Proxy) withinDNS64(ip netip.Addr) (ok bool) {
	for i := range p.dns64Prefs {
		if ok = p.dns64Prefs[i].Contains(ip); ok {
			break
		}
	}

	return ok
}

// shouldStripDNS64 reports whether DNS64 is enabled and ip is within either one
// of the configured DNS64 prefixes or the Well-Known one.  It is intended to be
// used with PTR requests.
//
// The requirement is to match any Pref64::/n used at the site, and not merely
// the locally configured Pref64::/n, because end clients could ask for a PTR
// record matching an address received through a different (site-provided)
// DNS64.
//
// See https://datatracker.ietf.org/doc/html/rfc6147#section-5.3.1.
func (p *Proxy) shouldStripDNS64(ip net.IP) (ok bool) {
	if len(p.dns64Prefs) == 0 {
		return false
	}

	ip6, err := netutil.IPToAddr(ip, netutil.AddrFamilyIPv6)
	if err != nil {
		return false
	}

	if p.withinDNS64(ip6) {
		log.Debug("proxy: %s is within DNS64 custom prefix set", ip)

		return true
	}

	if dns64WellKnownPref.Contains(ip6) {
		log.Debug("proxy: %s is within DNS64 well-known prefix", ip)

		return true
	}

	return false
}

// mapDNS64 maps addr to an IPv6 address by placing its four bytes after the
// first NAT64PrefixLength bytes of the first configured DNS64 prefix.  addr
// must be a valid IPv4 address.  It panics if there are no configured DNS64
// prefixes, because synthesis should not be performed unless the DNS64
// function is enabled.
func (p *Proxy) mapDNS64(addr netip.Addr) (mapped net.IP) {
	// The prefix isn't masked here, since it should have already been masked
	// on the initialization stage.
	pref := p.dns64Prefs[0].Addr().As16()
	v4 := addr.As4()

	var buf [net.IPv6len]byte
	n := copy(buf[:], pref[:NAT64PrefixLength])
	copy(buf[n:], v4[:])

	return net.IP(buf[:])
}

// synthRR synthesizes a DNS64 resource record in compliance with RFC 6147.
// Records other than A are returned unchanged.  An A record is turned into an
// AAAA record with the same name and class, the address mapped via the DNS64
// prefix, and the smaller of the original TTL and soaTTL as its TTL.  It
// returns nil if the A record's address is not a valid IPv4 address.
func (p *Proxy) synthRR(rr dns.RR, soaTTL uint32) (result dns.RR) {
	a, isA := rr.(*dns.A)
	if !isA {
		return rr
	}

	v4, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
	if err != nil {
		log.Error("proxy: bad a record: %s", err)

		return nil
	}

	ttl := mathutil.Min(a.Hdr.Ttl, soaTTL)
	mapped := p.mapDNS64(v4)

	result = &dns.AAAA{
		Hdr: dns.RR_Header{
			Name:   a.Hdr.Name,
			Rrtype: dns.TypeAAAA,
			Class:  a.Hdr.Class,
			Ttl:    ttl,
		},
		AAAA: mapped,
	}

	return result
}

// performDNS64 sends a DNS64 A request for origReq when checkDNS64 decides it
// is needed and synthesizes the AAAA answer into origResp.  It returns the
// upstream that was used to perform the DNS64 request, or nil if the request
// was not performed, failed, or its response couldn't be used for synthesis.
func (p *Proxy) performDNS64(
	origReq *dns.Msg,
	origResp *dns.Msg,
	upstreams []upstream.Upstream,
) (u upstream.Upstream) {
	if origResp == nil {
		return nil
	}

	dns64Req := p.checkDNS64(origReq, origResp)
	if dns64Req == nil {
		return nil
	}

	host := origReq.Question[0].Name
	log.Debug("proxy: received an empty aaaa response for %q, checking dns64", host)

	dns64Resp, u, err := p.exchange(dns64Req, upstreams)
	if err != nil {
		log.Error("proxy: dns64 request failed: %s", err)

		return nil
	}

	if dns64Resp != nil && p.synthDNS64(origReq, origResp, dns64Resp) {
		// Use the same log prefix as the rest of the package.
		log.Debug("proxy: synthesized aaaa response for %q", host)

		return u
	}

	return nil
}
07070100000047000081A4000000000000000000000001650C592100002324000000000000000000000000000000000000002400000000dnsproxy-0.55.0/proxy/dns64_test.gopackage proxy

import (
	"net"
	"net/netip"
	"sync"
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// ipv4OnlyFqdn is the fully-qualified domain name that the test upstreams
// answer with A records only, so that its AAAA responses need DNS64.
const ipv4OnlyFqdn = "ipv4.only."

// TestDNS64Race sends many concurrent AAAA requests for an IPv4-only host, each
// over its own TCP connection, and checks that every response contains a
// synthesized AAAA record.  Presumably it's meant to catch data races when run
// with -race.
func TestDNS64Race(t *testing.T) {
	// Restore the global log level so that other tests aren't affected.
	prevLevel := log.GetLevel()
	t.Cleanup(func() { log.SetLevel(prevLevel) })
	log.SetLevel(log.DEBUG)

	dnsProxy := createTestProxy(t, nil)

	ans := newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, net.ParseIP("1.2.3.4"))
	ups := upstreamFunc(func(req *dns.Msg) (resp *dns.Msg, err error) {
		resp = (&dns.Msg{}).SetReply(req)
		if req.Question[0].Qtype == dns.TypeA {
			resp.Answer = []dns.RR{dns.Copy(ans)}
		}

		return resp, nil
	})

	dnsProxy.UseDNS64 = true
	// Valid NAT-64 prefix for 2001:67c:27e4:15::64 server.
	dnsProxy.DNS64Prefs = []netip.Prefix{netip.MustParsePrefix("2001:67c:27e4:1064::/96")}
	dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{ups}

	require.NoError(t, dnsProxy.Start())
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	syncCh := make(chan struct{})

	// Send requests.
	g := &sync.WaitGroup{}
	g.Add(testMessagesCount)

	addr := dnsProxy.Addr(ProtoTCP).String()
	for i := 0; i < testMessagesCount; i++ {
		// The [dns.Conn] isn't safe for concurrent use despite the requirements
		// from the [net.Conn] documentation.
		conn, err := dns.Dial("tcp", addr)
		require.NoError(t, err)
		// Don't leak the connections.  Cleanups run in reverse order, so they
		// are closed before the proxy is stopped.
		testutil.CleanupAndRequireSuccess(t, conn.Close)

		go sendTestAAAAMessageAsync(conn, g, ipv4OnlyFqdn, syncCh)
	}

	close(syncCh)
	g.Wait()
}

// sendTestAAAAMessageAsync waits until syncCh is closed, sends an AAAA request
// for fqdn over conn, and checks that the response is successful and that its
// first answer is an AAAA record.  It calls g.Done when it returns.  It's run
// in a separate goroutine, so it reports failures through [testutil.PanicT]
// instead of a *testing.T.
func sendTestAAAAMessageAsync(conn *dns.Conn, g *sync.WaitGroup, fqdn string, syncCh chan struct{}) {
	pt := testutil.PanicT{}

	defer g.Done()

	req := (&dns.Msg{}).SetQuestion(fqdn, dns.TypeAAAA)
	<-syncCh

	err := conn.WriteMsg(req)
	require.NoError(pt, err)

	res, err := conn.ReadMsg()
	require.NoError(pt, err)
	// The expected value goes first, so that failure messages aren't inverted.
	require.Equal(pt, dns.RcodeSuccess, res.Rcode)
	require.NotEmpty(pt, res.Answer)

	require.IsType(pt, &dns.AAAA{}, res.Answer[0])
}

// newRR is a test helper that builds a dns.RR of type qtype with the given
// owner name, TTL, and class IN.  val holds the type-specific data: a net.IP
// for A and AAAA, a string for CNAME and PTR.  For SOA, val is ignored and the
// record gets fixed values derived from name.  The test fails on any other
// qtype or if val has the wrong type.
func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns.RR) {
	t.Helper()

	hdr := dns.RR_Header{
		Name:   name,
		Rrtype: qtype,
		Class:  dns.ClassINET,
		Ttl:    ttl,
	}

	switch qtype {
	case dns.TypeA:
		rr = &dns.A{Hdr: hdr, A: testutil.RequireTypeAssert[net.IP](t, val)}
	case dns.TypeAAAA:
		rr = &dns.AAAA{Hdr: hdr, AAAA: testutil.RequireTypeAssert[net.IP](t, val)}
	case dns.TypeCNAME:
		rr = &dns.CNAME{Hdr: hdr, Target: testutil.RequireTypeAssert[string](t, val)}
	case dns.TypePTR:
		rr = &dns.PTR{Hdr: hdr, Ptr: testutil.RequireTypeAssert[string](t, val)}
	case dns.TypeSOA:
		rr = &dns.SOA{
			Hdr:     hdr,
			Ns:      "ns." + name,
			Mbox:    "hostmaster." + name,
			Serial:  1,
			Refresh: 1,
			Retry:   1,
			Expire:  1,
			Minttl:  1,
		}
	default:
		t.Fatalf("unsupported qtype: %d", qtype)
	}

	return rr
}

// upstreamFunc adapts a plain function to the [upstream.Upstream] interface,
// so that tests can script upstream responses inline.
type upstreamFunc func(req *dns.Msg) (resp *dns.Msg, err error)

// type check
var _ upstream.Upstream = upstreamFunc(nil)

// Exchange implements the [upstream.Upstream] interface for upstreamFunc.  It
// simply calls u with req.
func (u upstreamFunc) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	resp, err = u(req)

	return resp, err
}

// Address implements the [upstream.Upstream] interface for upstreamFunc.  It
// always returns the same placeholder address.
func (u upstreamFunc) Address() (addr string) {
	addr = "func.upstream"

	return addr
}

// Close implements the [upstream.Upstream] interface for upstreamFunc.  There
// is nothing to release, so it always returns nil.
func (u upstreamFunc) Close() (err error) {
	return nil
}

// TestProxy_Resolve_dns64 checks how Resolve handles requests when DNS64 is
// enabled:
//
//   - plain A and AAAA answers are passed through;
//   - empty AAAA answers for IPv4-only hosts get synthesized AAAA records;
//   - AAAA records within the DNS64 prefix are filtered out of answers;
//   - PTR requests for addresses within the DNS64 prefix are sent to the
//     private rDNS upstream, while other PTR requests go to the general one.
func TestProxy_Resolve_dns64(t *testing.T) {
	const (
		ipv6Domain    = "ipv6.only."
		soaDomain     = "ipv4.soa."
		mappedDomain  = "filterable.ipv6."
		anotherDomain = "another.domain."

		pointedDomain = "local1234.ipv4."
		globDomain    = "real1234.ipv4."
	)

	someIPv4 := net.IP{1, 2, 3, 4}
	someIPv6 := net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
	mappedIPv6 := net.ParseIP("64:ff9b::102:304")

	ptr64Domain, err := netutil.IPToReversedAddr(mappedIPv6)
	require.NoError(t, err)
	ptr64Domain = dns.Fqdn(ptr64Domain)

	ptrGlobDomain, err := netutil.IPToReversedAddr(someIPv4)
	require.NoError(t, err)
	ptrGlobDomain = dns.Fqdn(ptrGlobDomain)

	cliIP := &net.TCPAddr{
		IP:   net.IP{192, 168, 1, 1},
		Port: 1234,
	}

	const (
		sectionAnswer = iota
		sectionAuthority
		sectionAdditional

		sectionsNum
	)

	// answerMap is a convenience alias for describing the upstream response for
	// a given question type.
	type answerMap = map[uint16][sectionsNum][]dns.RR

	pt := testutil.PanicT{}
	newUps := func(answers answerMap) (u upstream.Upstream) {
		return upstreamFunc(func(req *dns.Msg) (resp *dns.Msg, err error) {
			q := req.Question[0]
			require.Contains(pt, answers, q.Qtype)

			answer := answers[q.Qtype]

			resp = (&dns.Msg{}).SetReply(req)
			resp.Answer = answer[sectionAnswer]
			resp.Ns = answer[sectionAuthority]
			resp.Extra = answer[sectionAdditional]

			return resp, nil
		})
	}

	localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain)
	localUps := upstreamFunc(func(req *dns.Msg) (resp *dns.Msg, err error) {
		// The expected value goes first, so that failure messages aren't
		// inverted.
		require.Equal(pt, ptr64Domain, req.Question[0].Name)
		resp = (&dns.Msg{}).SetReply(req)
		resp.Answer = []dns.RR{localRR}

		return resp, nil
	})

	testCases := []struct {
		name    string
		qname   string
		upsAns  answerMap
		wantAns []dns.RR
		qtype   uint16
	}{{
		name:  "simple_a",
		qname: ipv4OnlyFqdn,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {},
		},
		wantAns: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   ipv4OnlyFqdn,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			A: someIPv4,
		}},
		qtype: dns.TypeA,
	}, {
		name:  "simple_aaaa",
		qname: ipv6Domain,
		upsAns: answerMap{
			dns.TypeA: {},
			dns.TypeAAAA: {
				sectionAnswer: {newRR(t, ipv6Domain, dns.TypeAAAA, 3600, someIPv6)},
			},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   ipv6Domain,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			AAAA: someIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "actual_dns64",
		qname: ipv4OnlyFqdn,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, ipv4OnlyFqdn, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   ipv4OnlyFqdn,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    maxDNS64SynTTL,
			},
			AAAA: mappedIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "actual_dns64_soattl",
		qname: soaDomain,
		upsAns: answerMap{
			dns.TypeA: {
				sectionAnswer: {newRR(t, soaDomain, dns.TypeA, 3600, someIPv4)},
			},
			dns.TypeAAAA: {
				sectionAuthority: {newRR(t, soaDomain, dns.TypeSOA, maxDNS64SynTTL+50, nil)},
			},
		},
		wantAns: []dns.RR{&dns.AAAA{
			Hdr: dns.RR_Header{
				Name:   soaDomain,
				Rrtype: dns.TypeAAAA,
				Class:  dns.ClassINET,
				Ttl:    maxDNS64SynTTL + 50,
			},
			AAAA: mappedIPv6,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:  "filtered",
		qname: mappedDomain,
		upsAns: answerMap{
			dns.TypeA: {},
			dns.TypeAAAA: {
				sectionAnswer: {
					newRR(t, mappedDomain, dns.TypeAAAA, 3600, net.ParseIP("64:ff9b::506:708")),
					newRR(t, mappedDomain, dns.TypeCNAME, 3600, anotherDomain),
				},
			},
		},
		wantAns: []dns.RR{&dns.CNAME{
			Hdr: dns.RR_Header{
				Name:   mappedDomain,
				Rrtype: dns.TypeCNAME,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Target: anotherDomain,
		}},
		qtype: dns.TypeAAAA,
	}, {
		name:   "ptr",
		qname:  ptr64Domain,
		upsAns: nil,
		wantAns: []dns.RR{&dns.PTR{
			Hdr: dns.RR_Header{
				Name:   ptr64Domain,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Ptr: pointedDomain,
		}},
		qtype: dns.TypePTR,
	}, {
		name:  "ptr_glob",
		qname: ptrGlobDomain,
		upsAns: answerMap{
			dns.TypePTR: {
				sectionAnswer: {newRR(t, ptrGlobDomain, dns.TypePTR, 3600, globDomain)},
			},
		},
		wantAns: []dns.RR{&dns.PTR{
			Hdr: dns.RR_Header{
				Name:   ptrGlobDomain,
				Rrtype: dns.TypePTR,
				Class:  dns.ClassINET,
				Ttl:    3600,
			},
			Ptr: globDomain,
		}},
		qtype: dns.TypePTR,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			p := createTestProxy(t, nil)
			p.Config.UpstreamConfig.Upstreams = []upstream.Upstream{newUps(tc.upsAns)}
			p.Config.PrivateRDNSUpstreamConfig = &UpstreamConfig{
				Upstreams: []upstream.Upstream{localUps},
			}
			p.Config.UseDNS64 = true

			require.NoError(t, p.Start())
			testutil.CleanupAndRequireSuccess(t, p.Stop)

			req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
			dctx := &DNSContext{
				Req:  req,
				Addr: cliIP,
			}

			// Use a subtest-local error instead of writing to the captured
			// outer variable.
			err := p.Resolve(dctx)
			require.NoError(t, err)

			res := dctx.Res
			require.NotNil(t, res)
			assert.Equal(t, tc.wantAns, res.Answer)
		})
	}
}
07070100000048000081A4000000000000000000000001650C592100001169000000000000000000000000000000000000002400000000dnsproxy-0.55.0/proxy/dnscontext.gopackage proxy

import (
	"net"
	"net/http"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/mathutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
)

// DNSContext represents a DNS request message context.  It holds the request,
// the response, and the client- and protocol-specific data needed to process
// and answer the request.
type DNSContext struct {
	// StartTime is the moment when request processing started.
	StartTime time.Time

	// Conn is the underlying client connection.  It is nil if Proto is
	// ProtoDNSCrypt, ProtoHTTPS, or ProtoQUIC.
	Conn net.Conn

	// QUICConnection is the QUIC session from which we got the query.  For
	// ProtoQUIC only.
	QUICConnection quic.Connection

	// QUICStream is the QUIC stream from which we got the query.  For
	// ProtoQUIC only.
	QUICStream quic.Stream

	// Addr is the address of the client.
	Addr net.Addr

	// Upstream is the upstream that resolved the request.  In case of cached
	// response it's nil.
	Upstream upstream.Upstream

	// DNSCryptResponseWriter is necessary to respond to a DNSCrypt query.
	DNSCryptResponseWriter dnscrypt.ResponseWriter

	// HTTPResponseWriter is the HTTP response writer.  For DoH only.
	HTTPResponseWriter http.ResponseWriter

	// HTTPRequest is the HTTP request.  For DoH only.
	HTTPRequest *http.Request

	// ReqECS is the EDNS Client Subnet used in the request.
	ReqECS *net.IPNet

	// CustomUpstreamConfig is only used for current request.  The Resolve
	// method of Proxy uses it instead of the default servers if it's not nil.
	CustomUpstreamConfig *UpstreamConfig

	// Req is the request message.
	Req *dns.Msg
	// Res is the response message.
	Res *dns.Msg

	// Proto is the protocol over which the request was received.  Among other
	// things, scrub uses it to decide the maximum size of the response.
	Proto Proto

	// CachedUpstreamAddr is the address of the upstream which the answer was
	// cached with.  It's empty for responses resolved by the upstream server.
	CachedUpstreamAddr string

	// localIP is the local IP address, used for the UDP socket to call
	// udpMakeOOBWithSrc.
	localIP net.IP

	// DoQVersion is the DoQ protocol version.  It can (and should) be read
	// from ALPN, but in the current version we also use the way DNS messages
	// are encoded as a signal.
	DoQVersion DoQVersion

	// RequestID is an opaque numerical identifier of this request that is
	// guaranteed to be unique across requests processed by a single Proxy
	// instance.
	RequestID uint64

	// udpSize is the UDP buffer size from request's EDNS0 RR if presented,
	// or default otherwise.  It's calculated lazily by calcFlagsAndSize, and
	// zero means it hasn't been calculated yet.
	udpSize uint16

	// adBit is the authenticated data flag from the request.
	adBit bool
	// hasEDNS0 reflects if the request has EDNS0 RRs.
	hasEDNS0 bool
	// doBit is the DNSSEC OK flag from request's EDNS0 RR if presented.
	doBit bool
}

// calcFlagsAndSize lazily fills adBit, hasEDNS0, doBit, and udpSize from the
// request, which the Resolve method needs.  It does nothing if there is no
// request or if udpSize is already non-zero.
func (dctx *DNSContext) calcFlagsAndSize() {
	// A non-zero udpSize means that the values have already been calculated.
	alreadyDone := dctx.udpSize != 0
	if alreadyDone || dctx.Req == nil {
		return
	}

	req := dctx.Req
	dctx.adBit = req.AuthenticatedData

	opt := req.IsEdns0()
	if opt == nil {
		dctx.udpSize = defaultUDPBufSize

		return
	}

	dctx.hasEDNS0 = true
	dctx.doBit = opt.Do()
	dctx.udpSize = opt.UDPSize()
}

// scrub prepares dctx.Res to be written to the client:
//
//   - it adds an OPT RR to the response if the request had one but the
//     response doesn't;
//   - it truncates the response to the size allowed for the protocol;
//   - it enables message compression.
//
// It does nothing if either the request or the response is nil.
func (dctx *DNSContext) scrub() {
	req, res := dctx.Req, dctx.Res
	if req == nil || res == nil {
		return
	}

	// Make sure that hasEDNS0, udpSize, and doBit are calculated.
	dctx.calcFlagsAndSize()

	// An OPT RR is only ever added here when the request carried one, per
	// RFC 6891 (https://tools.ietf.org/html/rfc6891).
	//
	// See https://github.com/AdguardTeam/dnsproxy/issues/132.
	needOPT := dctx.hasEDNS0 && res.IsEdns0() == nil
	if needOPT {
		res.SetEdns0(dctx.udpSize, dctx.doBit)
	}

	isUDP := dctx.Proto == ProtoUDP
	res.Truncate(dnsSize(isUDP, req))

	// Some devices require DNS message compression.
	res.Compress = true
}

// dnsSize returns the maximum size of a response to r.  For UDP, that is the
// buffer size advertised in the OPT record of r, but never less than
// dns.MinMsgSize.  For other protocols, it is dns.MaxMsgSize.
func dnsSize(isUDP bool, r *dns.Msg) (size int) {
	if isUDP {
		var advertised uint16
		if opt := r.IsEdns0(); opt != nil {
			advertised = opt.UDPSize()
		}

		size = int(mathutil.Max(dns.MinMsgSize, advertised))
	} else {
		size = dns.MaxMsgSize
	}

	return size
}

// DoQVersion is an enumeration with supported DoQ versions.  The version used
// for a request is stored in DNSContext.DoQVersion.
type DoQVersion int

const (
	// DoQv1Draft represents old DoQ draft versions that do not send a 2-octet
	// prefix with the DNS message length.
	//
	// TODO(ameshkov): remove in the end of 2024.
	DoQv1Draft DoQVersion = 0x00

	// DoQv1 represents DoQ v1.0, in which DNS messages are prefixed with their
	// 2-octet length: https://www.rfc-editor.org/rfc/rfc9250.html.
	DoQv1 DoQVersion = 0x01
)
07070100000049000081A4000000000000000000000001650C592100000219000000000000000000000000000000000000002000000000dnsproxy-0.55.0/proxy/errors.go//go:build !plan9
// +build !plan9

package proxy

import (
	"syscall"

	"github.com/AdguardTeam/golibs/errors"
)

// isEPIPE checks if the underlying error is EPIPE.  syscall.EPIPE exists on all
// OSes except for Plan 9.  Validate with:
//
//	$ for os in $(go tool dist list | cut -d / -f 1 | sort -u)
//	do
//	        echo -n "$os"
//	        env GOOS="$os" go doc syscall.EPIPE | grep -F -e EPIPE
//	done
//
// For the Plan 9 version see ./errors_plan9.go.
func isEPIPE(err error) (ok bool) {
	return errors.Is(err, syscall.EPIPE)
}
0707010000004A000081A4000000000000000000000001650C592100000232000000000000000000000000000000000000002600000000dnsproxy-0.55.0/proxy/errors_plan9.go//go:build plan9
// +build plan9

package proxy

import "strings"

// isEPIPE checks if the underlying error is EPIPE.  Plan 9 relies on error
// strings instead of error codes, so the check is done on the error text.  I
// couldn't find the exact constant with the text returned by a write on a
// closed socket, but it seems to be "sys: write on closed pipe".  See Plan 9's
// "man 2 notify".
//
// A nil err is never EPIPE.  This matches errors.Is(nil, syscall.EPIPE), which
// the implementation for other OSes uses, and avoids calling Error on a nil
// error.
//
// We don't currently support Plan 9, so it's not critical, but when we do, this
// needs to be rechecked.
func isEPIPE(err error) (ok bool) {
	if err == nil {
		return false
	}

	return strings.Contains(err.Error(), "write on closed pipe")
}
0707010000004B000081A4000000000000000000000001650C5921000002D1000000000000000000000000000000000000002500000000dnsproxy-0.55.0/proxy/errors_test.go//go:build !plan9
// +build !plan9

package proxy

import (
	"fmt"
	"syscall"
	"testing"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/stretchr/testify/assert"
)

// TestIsEPIPE checks that isEPIPE recognizes syscall.EPIPE, both bare and
// wrapped, and rejects nil and unrelated errors.
func TestIsEPIPE(t *testing.T) {
	testCases := []struct {
		err  error
		name string
		want bool
	}{{
		err:  nil,
		name: "nil",
		want: false,
	}, {
		err:  syscall.EPIPE,
		name: "epipe",
		want: true,
	}, {
		err:  errors.Error("test error"),
		name: "not_epipe",
		want: false,
	}, {
		err:  fmt.Errorf("test error: %w", syscall.EPIPE),
		name: "wrapped_epipe",
		want: true,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, isEPIPE(tc.err))
		})
	}
}
0707010000004C000081A4000000000000000000000001650C592100000AB3000000000000000000000000000000000000002200000000dnsproxy-0.55.0/proxy/exchange.gopackage proxy

import (
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
	"golang.org/x/exp/slices"
)

// exchange sends req to upstreams according to p.UpstreamMode.  It returns the
// reply, the upstream that produced it, and an error, if any.
//
//   - In UModeFastestAddr, A and AAAA requests are handled by p.fastestAddr.
//   - In UModeParallel, all upstreams are queried simultaneously.
//   - Otherwise (load balancing), a single upstream is queried directly, while
//     several upstreams are tried one by one in the order of their average
//     RTT.  The RTT statistics are updated after each attempt.
func (p *Proxy) exchange(req *dns.Msg, upstreams []upstream.Upstream) (reply *dns.Msg, u upstream.Upstream, err error) {
	qtype := req.Question[0].Qtype
	isAddrQuery := qtype == dns.TypeA || qtype == dns.TypeAAAA

	switch {
	case p.UpstreamMode == UModeFastestAddr && isAddrQuery:
		return p.fastestAddr.ExchangeFastest(req, upstreams)
	case p.UpstreamMode == UModeParallel:
		return upstream.ExchangeParallel(upstreams, req)
	case len(upstreams) == 1:
		u = upstreams[0]
		reply, _, err = exchangeWithUpstream(u, req)

		return reply, u, err
	}

	errs := make([]error, 0, len(upstreams))
	for _, cand := range p.getSortedUpstreams(upstreams) {
		r, elapsed, exchErr := exchangeWithUpstream(cand, req)
		if exchErr != nil {
			errs = append(errs, exchErr)
			// Penalize the failed upstream with the whole default timeout.
			p.updateRtt(cand.Address(), int(defaultTimeout/time.Millisecond))

			continue
		}

		p.updateRtt(cand.Address(), elapsed)

		return r, cand, nil
	}

	return nil, nil, errors.List("all upstreams failed to exchange request", errs...)
}

// getSortedUpstreams returns a copy of u sorted by the RTT values recorded in
// p.upstreamRttStats, fastest first.  Upstreams without a recorded value count
// as zero.  The slice u itself is not modified.
func (p *Proxy) getSortedUpstreams(u []upstream.Upstream) []upstream.Upstream {
	// Sort a copy to avoid race conditions on the caller's slice.
	sorted := slices.Clone(u)

	p.rttLock.Lock()
	defer p.rttLock.Unlock()

	rttOf := func(ups upstream.Upstream) (rtt int) {
		return p.upstreamRttStats[ups.Address()]
	}

	slices.SortFunc(sorted, func(a, b upstream.Upstream) (res int) {
		// TODO(d.kolyshev): Use upstreams for sort comparing.
		return rttOf(a) - rttOf(b)
	})

	return sorted
}

// exchangeWithUpstream sends req to u and returns the reply, the duration of
// the exchange in milliseconds, and the error, if any.  Failures are logged at
// the error level, successes at the trace level.
func exchangeWithUpstream(u upstream.Upstream, req *dns.Msg) (*dns.Msg, int, error) {
	start := time.Now()
	reply, err := u.Exchange(req)
	took := time.Since(start)

	addr, q := u.Address(), req.Question[0].String()
	if err == nil {
		log.Tracef("upstream %s successfully finished exchange of %s. Elapsed %s.", addr, q, took)
	} else {
		log.Error("upstream %s failed to exchange %s in %s. Cause: %s", addr, q, took, err)
	}

	return reply, int(took.Milliseconds()), err
}

// updateRtt records the round-trip time rtt, in milliseconds, for the upstream
// with the given address in p.upstreamRttStats.  The stored value is a running
// average in which every new sample weighs half.  The first sample for an
// address is stored as is, instead of being averaged with a missing value of
// zero, which would halve it and make the upstream look faster than it is.
func (p *Proxy) updateRtt(address string, rtt int) {
	p.rttLock.Lock()
	defer p.rttLock.Unlock()

	if p.upstreamRttStats == nil {
		p.upstreamRttStats = map[string]int{}
	}

	prev, ok := p.upstreamRttStats[address]
	if !ok {
		p.upstreamRttStats[address] = rtt

		return
	}

	p.upstreamRttStats[address] = (prev + rtt) / 2
}
0707010000004D000081A4000000000000000000000001650C59210000060D000000000000000000000000000000000000002600000000dnsproxy-0.55.0/proxy/handler_test.gopackage proxy

import (
	"sync"
	"testing"
	"time"

	"github.com/miekg/dns"
)

// TestFilteringHandler checks that a custom RequestHandler can either pass a
// request to the default Resolve method or answer it by itself.
func TestFilteringHandler(t *testing.T) {
	// Initialize the test middleware state.
	m := sync.RWMutex{}
	blockResponse := false

	// Prepare the proxy server.
	dnsProxy := createTestProxy(t, nil)
	dnsProxy.RequestHandler = func(p *Proxy, d *DNSContext) error {
		// Only the flag needs protection, so don't hold the lock while
		// resolving.
		m.RLock()
		block := blockResponse
		m.RUnlock()

		if !block {
			// Use the default Resolve method if response is not blocked.
			return p.Resolve(d)
		}

		resp := dns.Msg{}
		resp.SetRcode(d.Req, dns.RcodeNotImplemented)
		resp.RecursionAvailable = true

		// Set the response right away.
		d.Res = &resp

		return nil
	}

	// Start listening.
	err := dnsProxy.Start()
	if err != nil {
		t.Fatalf("cannot start the DNS proxy: %s", err)
	}

	// Stop the proxy even if one of the checks below fails.
	t.Cleanup(func() {
		if stopErr := dnsProxy.Stop(); stopErr != nil {
			t.Errorf("cannot stop the DNS proxy: %s", stopErr)
		}
	})

	// Create a DNS-over-UDP client connection.
	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}

	// Send the first message (not blocked).
	req := createTestMessage()

	r, _, err := client.Exchange(req, addr.String())
	if err != nil {
		t.Fatalf("error in the first request: %s", err)
	}
	requireResponse(t, req, r)

	// Now send the second and make sure it is blocked.
	m.Lock()
	blockResponse = true
	m.Unlock()

	r, _, err = client.Exchange(req, addr.String())
	if err != nil {
		t.Fatalf("error in the second request: %s", err)
	}
	if r.Rcode != dns.RcodeNotImplemented {
		t.Fatalf("second request was not blocked")
	}
}
0707010000004E000081A4000000000000000000000001650C592100001037000000000000000000000000000000000000002100000000dnsproxy-0.55.0/proxy/helpers.gopackage proxy

import (
	"net"

	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

// retryNoError is the value of the Retry field, in seconds, of the SOA record
// in the empty NoError responses generated by genEmptyNoError.
const retryNoError = 60 // Retry time for NoError SOA

// CheckDisabledAAAARequest checks if AAAA requests should be disabled or not.
// If ipv6Disabled is true and the request in ctx is an AAAA one, it sets
// ctx.Res to an empty NoError response and returns true.  Otherwise, it leaves
// ctx unchanged and returns false.  A request without questions is never
// treated as an AAAA one, so it can't cause an out-of-range panic.
func CheckDisabledAAAARequest(ctx *DNSContext, ipv6Disabled bool) bool {
	if !ipv6Disabled || len(ctx.Req.Question) == 0 {
		return false
	}

	q := ctx.Req.Question[0]
	if q.Qtype != dns.TypeAAAA {
		return false
	}

	log.Debug("IPv6 is disabled. Reply with NoError to %s AAAA request", q.Name)
	ctx.Res = genEmptyNoError(ctx.Req)

	return true
}

// GenEmptyMessage returns a reply to request with the response code rCode and
// the RA flag set.  Its authority section holds the SOA record from genSOA,
// whose Retry field is set to retry.
func GenEmptyMessage(request *dns.Msg, rCode int, retry uint32) *dns.Msg {
	resp := (&dns.Msg{}).SetRcode(request, rCode)
	resp.RecursionAvailable = true
	resp.Ns = genSOA(request, retry)

	return resp
}

// genEmptyNoError returns an empty NoError (dns.RcodeSuccess) reply to
// request, with the Retry field of its SOA record set to retryNoError.
func genEmptyNoError(request *dns.Msg) *dns.Msg {
	const rcode = dns.RcodeSuccess

	return GenEmptyMessage(request, rcode, retryNoError)
}

// genSOA returns a single synthetic SOA record for the authority section of a
// reply to request.  The owner name is the name from the first question of
// request, or empty if there are none, and the Retry field is set to retry.
func genSOA(request *dns.Msg, retry uint32) []dns.RR {
	var zone string
	if len(request.Question) != 0 {
		zone = request.Question[0].Name
	}

	// The mailbox is "hostmaster." followed by the zone, unless the zone is
	// empty or starts with a dot, like the root zone does.
	mbox := "hostmaster."
	if zone != "" && zone[0] != '.' {
		mbox += zone
	}

	soa := &dns.SOA{
		Hdr: dns.RR_Header{
			Name:   zone,
			Rrtype: dns.TypeSOA,
			Class:  dns.ClassINET,
			Ttl:    10,
		},
		// Ns and Serial are copied from AdGuard DNS.
		Ns:     "fake-for-negative-caching.adguard.com.",
		Mbox:   mbox,
		Serial: 100500,
		// The timers are copied from Verisign's response for a nonexistent
		// .com domain.  Their exact values aren't important here, since they
		// are used for domain transfers between primary and secondary DNS
		// servers.
		Refresh: 1800,
		Retry:   retry,
		Expire:  604800,
		Minttl:  86400,
	}

	return []dns.RR{soa}
}

// ecsFromMsg returns the subnet and the scope prefix length from the first EDNS
// Client Subnet option of m with a supported address family, 1 for IPv4 or 2
// for IPv6.  It returns nil and 0 if m has no OPT record or no such option.
func ecsFromMsg(m *dns.Msg) (subnet *net.IPNet, scope int) {
	opt := m.IsEdns0()
	if opt == nil {
		return nil, 0
	}

	for _, o := range opt.Option {
		sn, isECS := o.(*dns.EDNS0_SUBNET)
		if !isECS {
			continue
		}

		var ip net.IP
		var bitLen int
		switch sn.Family {
		case 1:
			ip, bitLen = sn.Address.To4(), netutil.IPv4BitLen
		case 2:
			ip, bitLen = sn.Address, netutil.IPv6BitLen
		default:
			continue
		}

		subnet = &net.IPNet{
			IP:   ip,
			Mask: net.CIDRMask(int(sn.SourceNetmask), bitLen),
		}

		return subnet, int(sn.SourceScope)
	}

	return nil, 0
}

// setECS adds an EDNS Client Subnet option built from ip and scope to m.  The
// source prefix length is 24 for IPv4 and 56 for IPv6 addresses.  The option
// is appended to the existing OPT record of m or, if there is none, to a new
// OPT record with a UDP size of 4096.  It returns the masked subnet.
func setECS(m *dns.Msg, ip net.IP, scope uint8) (subnet *net.IPNet) {
	const (
		// defaultECSv4 is the default length of network mask for IPv4 address
		// in ECS option.
		defaultECSv4 = 24

		// defaultECSv6 is the default length of network mask for IPv6 address
		// in ECS.  The size of 7 octets is chosen as a reasonable minimum since
		// at least Google's public DNS refuses requests containing the options
		// with longer network masks.
		defaultECSv6 = 56
	)

	var (
		family  uint16
		maskLen int
		bitLen  int
	)
	if ip4 := ip.To4(); ip4 != nil {
		ip, family, maskLen, bitLen = ip4, 1, defaultECSv4, netutil.IPv4BitLen
	} else {
		// Assume the IP address has already been validated.
		family, maskLen, bitLen = 2, defaultECSv6, netutil.IPv6BitLen
	}

	mask := net.CIDRMask(maskLen, bitLen)
	subnet = &net.IPNet{IP: ip.Mask(mask), Mask: mask}

	e := &dns.EDNS0_SUBNET{
		Code:          dns.EDNS0SUBNET,
		Family:        family,
		SourceNetmask: uint8(maskLen),
		SourceScope:   scope,
		Address:       subnet.IP,
	}

	// Reuse an existing OPT record, since servers may return FORMERR if they
	// meet several OPT RRs.
	if opt := m.IsEdns0(); opt != nil {
		opt.Option = append(opt.Option, e)

		return subnet
	}

	opt := &dns.OPT{
		Hdr: dns.RR_Header{
			Name:   ".",
			Rrtype: dns.TypeOPT,
		},
		Option: []dns.EDNS0{e},
	}
	opt.SetUDPSize(4096)
	m.Extra = append(m.Extra, opt)

	return subnet
}
0707010000004F000081A4000000000000000000000001650C592100000780000000000000000000000000000000000000002000000000dnsproxy-0.55.0/proxy/lookup.gopackage proxy

import (
	"net"

	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/errors"

	"github.com/miekg/dns"
)

// lookupResult is a helper struct used to pass the results of the lookupIPAddr
// function through a channel.
type lookupResult struct {
	// resp is the response set by Resolve, if any.
	resp *dns.Msg
	// err is the error returned by Resolve.
	err  error
}

// lookupIPAddr resolves a recursive request of type qtype for host through p,
// as if it came over UDP, and sends the response and the error, if any, to ch.
func (p *Proxy) lookupIPAddr(host string, qtype uint16, ch chan *lookupResult) {
	req := &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id:               dns.Id(),
			RecursionDesired: true,
		},
		Question: []dns.Question{{
			Name:   host,
			Qtype:  qtype,
			Qclass: dns.ClassINET,
		}},
	}

	dctx := p.newDNSContext(ProtoUDP, req)
	err := p.Resolve(dctx)

	ch <- &lookupResult{resp: dctx.Res, err: err}
}

// ErrEmptyHost is returned by LookupIPAddr when the host is an empty string
// and therefore can't be resolved.
const ErrEmptyHost = errors.Error("host is empty")

// LookupIPAddr resolves the IP addresses of host.  It sends A and AAAA queries
// in parallel and returns the addresses from both answers, ordered by
// proxynetutil.SortIPAddrs according to p.Config.PreferIPv6.  An error is only
// returned if no addresses were found and at least one query failed; in that
// case, it's the first error received.  It returns ErrEmptyHost if host is
// empty.
func (p *Proxy) LookupIPAddr(host string) ([]net.IPAddr, error) {
	if host == "" {
		return nil, ErrEmptyHost
	}

	host = dns.Fqdn(host)

	ch := make(chan *lookupResult)
	go p.lookupIPAddr(host, dns.TypeA, ch)
	go p.lookupIPAddr(host, dns.TypeAAAA, ch)

	var ipAddrs []net.IPAddr
	var errs []error
	for n := 0; n < 2; n++ {
		result := <-ch
		switch {
		case result.err != nil:
			errs = append(errs, result.err)
		case result.resp != nil:
			// Copy IP addresses from dns.RR to the resulting IP slice.
			appendIPAddrs(&ipAddrs, result.resp.Answer)
		default:
			// Guard against a nil response even when Resolve reports no
			// error, since there is nothing to collect then.
		}
	}

	if len(ipAddrs) == 0 && len(errs) != 0 {
		return []net.IPAddr{}, errs[0]
	}

	proxynetutil.SortIPAddrs(ipAddrs, p.Config.PreferIPv6)

	return ipAddrs, nil
}

// appendIPAddrs appends the addresses from the A and AAAA records of answers to
// the slice pointed to by ipAddrs.  Records of other types are skipped.
func appendIPAddrs(ipAddrs *[]net.IPAddr, answers []dns.RR) {
	for _, rr := range answers {
		var ip net.IP
		if a, ok := rr.(*dns.A); ok {
			ip = a.A
		} else if aaaa, ok := rr.(*dns.AAAA); ok {
			ip = aaaa.AAAA
		} else {
			continue
		}

		*ipAddrs = append(*ipAddrs, net.IPAddr{IP: ip})
	}
}
07070100000050000081A4000000000000000000000001650C592100000428000000000000000000000000000000000000002500000000dnsproxy-0.55.0/proxy/lookup_test.gopackage proxy

import (
	"net"
	"testing"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestLookupIPAddr checks that LookupIPAddr returns the IPv4 and, if present,
// the IPv6 addresses of dns.google.  It requires network access to AdGuard DNS.
func TestLookupIPAddr(t *testing.T) {
	// Use AdGuard DNS as the only upstream.
	dnsUpstream, err := upstream.AddressToUpstream("94.140.14.14", &upstream.Options{
		Timeout: defaultTimeout,
	})
	require.NoError(t, err)
	// Release the upstream's resources when the test is done.
	t.Cleanup(func() { assert.NoError(t, dnsUpstream.Close()) })

	// Create a simple proxy.
	p := Proxy{}
	p.UpstreamConfig = &UpstreamConfig{
		Upstreams: []upstream.Upstream{dnsUpstream},
	}

	// Init the proxy.
	err = p.Init()
	require.NoError(t, err)

	// Now let's try doing some lookups.
	addrs, err := p.LookupIPAddr("dns.google")
	require.NoError(t, err)
	require.NotEmpty(t, addrs)

	assert.Contains(t, addrs, net.IPAddr{IP: net.IP{8, 8, 8, 8}})
	assert.Contains(t, addrs, net.IPAddr{IP: net.IP{8, 8, 4, 4}})
	if len(addrs) > 2 {
		assert.Contains(t, addrs, net.IPAddr{IP: net.ParseIP("2001:4860:4860::8888")})
		assert.Contains(t, addrs, net.IPAddr{IP: net.ParseIP("2001:4860:4860::8844")})
	}
}
07070100000051000081A4000000000000000000000001650C5921000006BC000000000000000000000000000000000000002C00000000dnsproxy-0.55.0/proxy/optimisticresolver.gopackage proxy

import (
	"encoding/hex"
	"sync"

	"github.com/AdguardTeam/golibs/log"
)

// cachingResolver is the DNS resolver that is also able to cache responses.
// optimisticResolver uses it to refresh expired cached responses.
type cachingResolver interface {
	// replyFromUpstream returns true if the request from dctx is successfully
	// resolved and the response may be cached.  A non-nil err doesn't
	// necessarily mean that ok is false.
	//
	// TODO(e.burkov):  Find out when ok can be false with nil err.
	replyFromUpstream(dctx *DNSContext) (ok bool, err error)

	// cacheResp caches the response from dctx.
	cacheResp(dctx *DNSContext)
}

// type check
var _ cachingResolver = (*Proxy)(nil)

// optimisticResolver is used to eventually resolve expired cached requests.
type optimisticResolver struct {
	// reqs holds the hex-encoded keys of the requests that are currently
	// being resolved by ResolveOnce.
	reqs *sync.Map
	// cr resolves the requests and caches the responses.
	cr   cachingResolver
}

// newOptimisticResolver returns a new resolver for expired cached requests
// that uses cr for resolving and caching.  cr must not be nil.
func newOptimisticResolver(cr cachingResolver) (s *optimisticResolver) {
	s = &optimisticResolver{cr: cr}
	s.reqs = &sync.Map{}

	return s
}

// unit is a convenient alias for struct{}.  It's used as a zero-size value,
// for example as the value stored in the reqs map of optimisticResolver.
type unit = struct{}

// ResolveOnce tries to resolve the request from dctx, making sure that only a
// single request with the same key is resolved at a time; concurrent calls
// with the same key return immediately.  The response is cached if resolving
// reports it as cacheable, and resolving errors are only logged.  It recovers
// from panics and is intended to run in a separate goroutine.  Do not pass the
// *DNSContext which is used elsewhere since it isn't intended to be used
// concurrently.
func (s *optimisticResolver) ResolveOnce(dctx *DNSContext, key []byte) {
	defer log.OnPanic("optimistic resolver")

	k := hex.EncodeToString(key)
	_, inProgress := s.reqs.LoadOrStore(k, unit{})
	if inProgress {
		// Another call is already resolving the same request.
		return
	}
	defer s.reqs.Delete(k)

	cacheable, err := s.cr.replyFromUpstream(dctx)
	if err != nil {
		log.Debug("resolving request for optimistic cache: %s", err)
	}

	if !cacheable {
		return
	}

	s.cr.cacheResp(dctx)
}
07070100000052000081A4000000000000000000000001650C592100000B53000000000000000000000000000000000000003100000000dnsproxy-0.55.0/proxy/optimisticresolver_test.gopackage proxy

import (
	"bytes"
	"sync"
	"testing"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/stretchr/testify/assert"
)

// testCachingResolver is a cachingResolver whose methods delegate to its
// function fields, which simplifies testing.
type testCachingResolver struct {
	onReplyFromUpstream func(dctx *DNSContext) (ok bool, err error)
	onCacheResp         func(dctx *DNSContext)
}

// replyFromUpstream implements the cachingResolver interface for
// *testCachingResolver by calling onReplyFromUpstream.
func (r *testCachingResolver) replyFromUpstream(dctx *DNSContext) (ok bool, err error) {
	ok, err = r.onReplyFromUpstream(dctx)

	return ok, err
}

// cacheResp implements the cachingResolver interface for *testCachingResolver
// by calling onCacheResp.
func (r *testCachingResolver) cacheResp(dctx *DNSContext) {
	r.onCacheResp(dctx)
}

// TestOptimisticResolver_ResolveOnce checks that while one call of ResolveOnce
// is in progress, other calls with the same key neither resolve nor cache the
// request.
func TestOptimisticResolver_ResolveOnce(t *testing.T) {
	var resolvedCount, cachedCount int
	release, started := make(chan unit), make(chan unit)

	s := newOptimisticResolver(&testCachingResolver{
		onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) {
			resolvedCount++

			return true, nil
		},
		onCacheResp: func(_ *DNSContext) {
			cachedCount++

			// Let the secondary goroutines start.
			started <- unit{}
			// Hold the key until every secondary goroutine is done.
			<-release
		},
	})

	key := []byte{1, 2, 3}

	// The primary goroutine stays inside ResolveOnce until release is
	// signaled.
	go s.ResolveOnce(nil, key)
	<-started

	const secondaryNum = 10

	wg := &sync.WaitGroup{}
	wg.Add(secondaryNum)
	for i := 0; i < secondaryNum; i++ {
		go func() {
			defer wg.Done()

			s.ResolveOnce(nil, key)
		}()
	}

	// Let the primary goroutine finish only after all the secondary ones.
	wg.Wait()
	release <- unit{}

	assert.Equal(t, 1, resolvedCount)
	assert.Equal(t, 1, cachedCount)
}

func TestOptimisticResolver_ResolveOnce_unsuccessful(t *testing.T) {
	key := []byte{1, 2, 3}

	t.Run("error", func(t *testing.T) {
		buf := &bytes.Buffer{}

		// Redirect the global logger into buf for the duration of the
		// subtest and restore the previous settings afterwards.
		origLevel, origOutput := log.GetLevel(), log.Writer()
		t.Cleanup(func() {
			log.SetLevel(origLevel)
			log.SetOutput(origOutput)
		})
		log.SetLevel(log.DEBUG)
		log.SetOutput(buf)

		const rerr errors.Error = "sample resolving error"
		tcr := &testCachingResolver{
			onReplyFromUpstream: func(_ *DNSContext) (bool, error) { return true, rerr },
			onCacheResp:         func(_ *DNSContext) {},
		}
		newOptimisticResolver(tcr).ResolveOnce(nil, key)

		// The resolving error must be reported in the log.
		assert.Contains(t, buf.String(), rerr.Error())
	})

	t.Run("not_ok", func(t *testing.T) {
		var cached bool
		tcr := &testCachingResolver{
			onReplyFromUpstream: func(_ *DNSContext) (bool, error) { return false, nil },
			onCacheResp:         func(_ *DNSContext) { cached = true },
		}
		newOptimisticResolver(tcr).ResolveOnce(nil, key)

		// A not-ok reply must never be cached.
		assert.False(t, cached)
	})
}
07070100000053000081A4000000000000000000000001650C592100004221000000000000000000000000000000000000001F00000000dnsproxy-0.55.0/proxy/proxy.go// Package proxy implements a DNS proxy that supports all known DNS
// encryption protocols.
package proxy

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/netip"
	"sync"
	"sync/atomic"
	"time"

	"github.com/AdguardTeam/dnsproxy/fastip"
	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
	gocache "github.com/patrickmn/go-cache"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
)

const (
	defaultTimeout   = 10 * time.Second
	minDNSPacketSize = 12 + 5
)

// Proto is the name of a DNS transport protocol the proxy can serve.
type Proto string

// Supported Proto values.
const (
	// ProtoUDP is plain DNS carried over UDP.
	ProtoUDP Proto = "udp"

	// ProtoTCP is plain DNS carried over TCP.
	ProtoTCP Proto = "tcp"

	// ProtoTLS is DNS-over-TLS (DoT).
	ProtoTLS Proto = "tls"

	// ProtoHTTPS is DNS-over-HTTPS (DoH).
	ProtoHTTPS Proto = "https"

	// ProtoQUIC is DNS-over-QUIC (DoQ).
	ProtoQUIC Proto = "quic"

	// ProtoDNSCrypt is the DNSCrypt protocol.
	ProtoDNSCrypt Proto = "dnscrypt"
)

const (
	// UnqualifiedNames is the reserved name that stands for "unqualified
	// names only", i.e. names that contain no dots.
	UnqualifiedNames = "unqualified_names"
)

// Proxy combines the proxy server state and configuration.  All of its mutable
// state is guarded by the embedded sync.RWMutex.
//
// TODO(a.garipov): Consider extracting conf blocks for better fieldalignment.
type Proxy struct {
	// counter is the counter of messages.  It must only be incremented
	// atomically, so it must be the first member of the struct to make sure
	// that it has a 64-bit alignment.  Don't move it.
	//
	// See https://golang.org/pkg/sync/atomic/#pkg-note-BUG.
	counter uint64

	// started indicates if the proxy has been started.  Start sets it and Stop
	// resets it.
	started bool

	// Listeners
	// --

	// udpListen are the listened UDP connections.
	udpListen []*net.UDPConn

	// tcpListen are the listened TCP connections.
	tcpListen []net.Listener

	// tlsListen are the listened TCP connections with TLS.
	tlsListen []net.Listener

	// quicListen are the listened QUIC connections.
	quicListen []*quic.EarlyListener

	// httpsListen are the listened HTTPS connections.  They are closed by
	// httpsServer on Stop.
	httpsListen []net.Listener

	// h3Listen are the listened HTTP/3 connections.
	h3Listen []*quic.EarlyListener

	// httpsServer serves queries received over HTTPS.
	httpsServer *http.Server

	// h3Server serves queries received over HTTP/3.
	h3Server *http3.Server

	// dnsCryptUDPListen are the listened UDP connections for DNSCrypt.
	dnsCryptUDPListen []*net.UDPConn

	// dnsCryptTCPListen are the listened TCP connections for DNSCrypt.
	dnsCryptTCPListen []net.Listener

	// dnsCryptServer serves DNSCrypt queries.
	dnsCryptServer *dnscrypt.Server

	// Upstream
	// --

	// upstreamRttStats is a map of upstream addresses and their rtt.  Used to
	// sort upstreams by their latency.
	upstreamRttStats map[string]int

	// rttLock protects upstreamRttStats.
	rttLock sync.Mutex

	// DNS64 (in case dnsproxy works in a NAT64/DNS64 network)
	// --

	// dns64Prefs is a set of NAT64 prefixes that are used to detect and
	// construct DNS64 responses.  The DNS64 function is disabled if it is
	// empty.
	dns64Prefs []netip.Prefix

	// Ratelimit
	// --

	// ratelimitBuckets is a storage for ratelimiters for individual IPs.
	ratelimitBuckets *gocache.Cache

	// ratelimitLock protects ratelimitBuckets.
	ratelimitLock sync.Mutex

	// proxyVerifier checks if the proxy is in the trusted list.  Init builds
	// it from Config.TrustedProxies.
	proxyVerifier netutil.SubnetSet

	// DNS cache
	// --

	// cache is used to cache requests.  It is disabled if nil.
	cache *cache

	// shortFlighter is used to resolve the expired cached requests without
	// repetitions.
	shortFlighter *optimisticResolver

	// FastestAddr module
	// --

	// fastestAddr finds the fastest IP address for the resolved domain.  Init
	// only sets it when UpstreamMode is UModeFastestAddr.
	fastestAddr *fastip.FastestAddr

	// Other
	// --

	// bytesPool is a pool of byte slices used to read DNS packets.  Init fills
	// it with pointers to buffers of 2+dns.MaxMsgSize bytes.
	bytesPool *sync.Pool

	// udpOOBSize is the size of the out-of-band data for UDP connections.
	udpOOBSize int

	// RWMutex protects the whole proxy.
	sync.RWMutex

	// requestGoroutinesSema limits the number of simultaneous requests.  Init
	// sets it to a no-op semaphore when MaxGoroutines is not positive.
	//
	// TODO(a.garipov): Currently we have to pass this exact semaphore to
	// the workers, to prevent races on restart.  In the future we will need
	// a better restarting mechanism that completely prevents such invalid
	// states.
	//
	// See also: https://github.com/AdguardTeam/AdGuardHome/issues/2242.
	requestGoroutinesSema semaphore

	// Config is the proxy configuration.
	Config
}

// Init populates the configuration-derived fields of p: it initializes the
// cache, the request semaphore, the UDP out-of-band buffer size, the packet
// buffer pool, the fastest-address module (in UModeFastestAddr mode only), the
// trusted-proxies verifier, and DNS64.  It doesn't start any listeners.
func (p *Proxy) Init() (err error) {
	p.initCache()

	if p.MaxGoroutines <= 0 {
		p.requestGoroutinesSema = newNoopSemaphore()
	} else {
		log.Info("dnsproxy: max goroutines is set to %d", p.MaxGoroutines)

		p.requestGoroutinesSema, err = newChanSemaphore(p.MaxGoroutines)
		if err != nil {
			return fmt.Errorf("can't init semaphore: %w", err)
		}
	}

	p.udpOOBSize = proxynetutil.UDPGetOOBSize()
	p.bytesPool = &sync.Pool{
		New: func() interface{} {
			// Reserve two extra bytes for the length prefix used by TCP and
			// TLS transports.
			buf := make([]byte, dns.MaxMsgSize+2)

			return &buf
		},
	}

	if p.UpstreamMode == UModeFastestAddr {
		log.Info("dnsproxy: fastest ip is enabled")

		p.fastestAddr = fastip.NewFastestAddr()
		if p.FastestPingTimeout > 0 {
			p.fastestAddr.PingWaitTimeout = p.FastestPingTimeout
		}
	}

	trusted, err := netutil.ParseSubnets(p.TrustedProxies...)
	if err != nil {
		return fmt.Errorf("initializing subnet detector for proxies verifying: %w", err)
	}
	p.proxyVerifier = netutil.SliceSubnetSet(trusted)

	if err = p.setupDNS64(); err != nil {
		return fmt.Errorf("setting up DNS64: %w", err)
	}

	return nil
}

// Start validates the configuration, initializes the proxy by calling Init,
// and starts all configured listeners.  It returns an error if the proxy is
// already started or if any of those steps fails.
func (p *Proxy) Start() (err error) {
	log.Info("dnsproxy: starting dns proxy server")

	p.Lock()
	defer p.Unlock()

	if p.started {
		return errors.Error("server has been already started")
	}

	if err = p.validateConfig(); err != nil {
		return err
	}

	if err = p.Init(); err != nil {
		return err
	}

	// TODO(a.garipov): Accept a context into this method.
	if err = p.startListeners(context.Background()); err != nil {
		return fmt.Errorf("starting listeners: %w", err)
	}

	p.started = true

	return nil
}

// closeAll closes all closers and appends the occurred errors to errs.
func closeAll[C io.Closer](errs []error, closers ...C) (appended []error) {
	for _, c := range closers {
		err := c.Close()
		if err != nil {
			errs = append(errs, err)
		}
	}

	return errs
}

// Stop stops the proxy server.  It closes the listeners and servers in a fixed
// order, then closes every non-nil upstream configuration, and marks the proxy
// as not started.  The errors of individual Close calls are collected and
// returned together.  Calling Stop on a proxy that isn't started is a no-op.
func (p *Proxy) Stop() error {
	log.Info("dnsproxy: stopping dns proxy server")

	p.Lock()
	defer p.Unlock()

	if !p.started {
		log.Info("dnsproxy: dns proxy server is not started")

		return nil
	}

	var errs []error

	errs = closeAll(errs, p.tcpListen...)
	errs = closeAll(errs, p.udpListen...)
	errs = closeAll(errs, p.tlsListen...)
	p.tcpListen, p.udpListen, p.tlsListen = nil, nil, nil

	if srv := p.httpsServer; srv != nil {
		errs = closeAll(errs, srv)
		p.httpsServer = nil

		// Closing httpsServer closes these listeners as well.
		p.httpsListen = nil
	}

	if srv := p.h3Server; srv != nil {
		errs = closeAll(errs, srv)
		p.h3Server = nil
	}

	errs = closeAll(errs, p.h3Listen...)
	errs = closeAll(errs, p.quicListen...)
	errs = closeAll(errs, p.dnsCryptUDPListen...)
	errs = closeAll(errs, p.dnsCryptTCPListen...)
	p.h3Listen, p.quicListen = nil, nil
	p.dnsCryptUDPListen, p.dnsCryptTCPListen = nil, nil

	confs := [...]*UpstreamConfig{
		p.UpstreamConfig,
		p.PrivateRDNSUpstreamConfig,
		p.Fallbacks,
	}
	for _, conf := range confs {
		if conf == nil {
			continue
		}

		errs = closeAll(errs, conf)
	}

	p.started = false

	log.Println("dnsproxy: stopped dns proxy server")

	if len(errs) == 0 {
		return nil
	}

	return errors.List("stopping dns proxy server", errs...)
}

// Addrs returns all listen addresses for the specified proto, or nil if the
// proxy does not listen on it.  proto must be one of ProtoTCP, ProtoTLS,
// ProtoHTTPS, ProtoUDP, ProtoQUIC, or ProtoDNSCrypt; any other value causes a
// panic.  For ProtoDNSCrypt, only the addresses of the UDP listeners are
// returned.
func (p *Proxy) Addrs(proto Proto) (addrs []net.Addr) {
	p.RLock()
	defer p.RUnlock()

	switch proto {
	case ProtoTCP:
		for _, l := range p.tcpListen {
			addrs = append(addrs, l.Addr())
		}
	case ProtoTLS:
		for _, l := range p.tlsListen {
			addrs = append(addrs, l.Addr())
		}
	case ProtoHTTPS:
		for _, l := range p.httpsListen {
			addrs = append(addrs, l.Addr())
		}
	case ProtoUDP:
		for _, l := range p.udpListen {
			addrs = append(addrs, l.LocalAddr())
		}
	case ProtoQUIC:
		for _, l := range p.quicListen {
			addrs = append(addrs, l.Addr())
		}
	case ProtoDNSCrypt:
		// Only the UDP addresses are returned here.
		//
		// TODO: To do it better we should either introduce separate
		// ProtoDNSCryptTCP and ProtoDNSCryptUDP, or change the configuration
		// so that it's impossible to set different ports for TCP and UDP
		// listeners.
		for _, l := range p.dnsCryptUDPListen {
			addrs = append(addrs, l.LocalAddr())
		}
	default:
		panic("proto must be 'tcp', 'tls', 'https', 'quic', 'dnscrypt' or 'udp'")
	}

	return addrs
}

// Addr returns the first listen address for the specified proto, or nil if the
// proxy does not listen on it.  proto must be one of ProtoTCP, ProtoTLS,
// ProtoHTTPS, ProtoUDP, ProtoQUIC, or ProtoDNSCrypt; any other value causes a
// panic.  For ProtoDNSCrypt, the address of the first UDP listener is
// returned.
func (p *Proxy) Addr(proto Proto) net.Addr {
	// Addrs takes the read lock itself, so don't take it here as well:
	// recursive read locking of a sync.RWMutex may deadlock when a writer is
	// waiting in between.
	addrs := p.Addrs(proto)
	if len(addrs) == 0 {
		return nil
	}

	return addrs[0]
}

// needsLocalUpstream returns true if req should be handled by the private
// upstream servers.  That is the case only for a PTR request whose reversed
// address parses successfully and is accepted by p.shouldStripDNS64.
func (p *Proxy) needsLocalUpstream(req *dns.Msg) (ok bool) {
	q := req.Question[0]
	if q.Qtype != dns.TypePTR {
		return false
	}

	ip, err := netutil.IPFromReversedAddr(q.Name)
	if err != nil {
		log.Debug("dnsproxy: failed to parse ip from ptr request: %s", err)

		return false
	}

	return p.shouldStripDNS64(ip)
}

// selectUpstreams returns the upstreams to use for the host from d's request.
// Requests that need a local upstream go to the private upstreams, but only
// when those are configured and the client's address is locally served.
// Other requests go to d's custom upstreams when those yield a non-empty
// result, and to the configured upstreams otherwise.  The returned slice may
// be empty or nil.
func (p *Proxy) selectUpstreams(d *DNSContext) (upstreams []upstream.Upstream) {
	host := d.Req.Question[0].Name

	if p.needsLocalUpstream(d.Req) {
		private := p.PrivateRDNSUpstreamConfig
		if private == nil {
			return nil
		}

		clientIP, _ := netutil.IPAndPortFromAddr(d.Addr)
		// TODO(e.burkov):  Detect against the actual configured subnet set.
		// Perhaps, even much earlier.
		if !netutil.IsLocallyServed(clientIP) {
			return nil
		}

		return private.getUpstreamsForDomain(host)
	}

	if custom := d.CustomUpstreamConfig; custom != nil {
		if upstreams = custom.getUpstreamsForDomain(host); len(upstreams) > 0 {
			return upstreams
		}
	}

	return p.UpstreamConfig.getUpstreamsForDomain(host)
}

// replyFromUpstream tries to resolve d's request using the upstreams chosen by
// selectUpstreams.  If that exchange fails and fallbacks are configured, the
// request is retried with the fallback upstreams for the question's name.  The
// result is stored into d by handleExchangeResult.  ok is true if a response
// has been received.  err is the last exchange error, if any.
func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) {
	req := d.Req

	upstreams := p.selectUpstreams(d)
	if len(upstreams) == 0 {
		return false, fmt.Errorf("selecting general upstream: %w", upstream.ErrNoUpstreams)
	}

	start := time.Now()

	// Perform the DNS request.
	resp, u, err := p.exchange(req, upstreams)
	if dns64Ups := p.performDNS64(req, resp, upstreams); dns64Ups != nil {
		u = dns64Ups
	} else if p.isBogusNXDomain(resp) {
		log.Debug("proxy: replying from upstream: response contains bogus-nxdomain ip")
		resp = p.genWithRCode(req, dns.RcodeNameError)
	}

	log.Debug("proxy: replying from upstream: rtt is %s", time.Since(start))

	if err != nil && p.Fallbacks != nil {
		log.Debug("proxy: replying from upstream: using fallback due to %s", err)

		upstreams = p.Fallbacks.getUpstreamsForDomain(req.Question[0].Name)
		if len(upstreams) == 0 {
			return false, fmt.Errorf("selecting fallback upstream: %w", upstream.ErrNoUpstreams)
		}

		// Time the fallback exchange separately, since the rtt logged above
		// only covers the failed primary attempt.
		start = time.Now()
		resp, u, err = upstream.ExchangeParallel(upstreams, req)
		log.Debug("proxy: replying from fallback: rtt is %s", time.Since(start))
	}

	p.handleExchangeResult(d, req, resp, u)

	return resp != nil, err
}

// handleExchangeResult stores the outcome of an upstream exchange into d.  For
// a non-nil resp, it records u as the resolving upstream, applies the
// configured TTL limits, and fills an empty question section from req.  For a
// nil resp, it sets a server failure response and clears d.hasEDNS0.
func (p *Proxy) handleExchangeResult(d *DNSContext, req, resp *dns.Msg, u upstream.Upstream) {
	if resp != nil {
		d.Upstream = u
		d.Res = resp

		p.setMinMaxTTL(resp)

		// Explicitly construct the question section since some upstreams may
		// respond with invalidly constructed messages which cause out-of-range
		// panics afterwards.
		//
		// See https://github.com/AdguardTeam/AdGuardHome/issues/3551.
		if len(resp.Question) == 0 && len(req.Question) > 0 {
			resp.Question = []dns.Question{req.Question[0]}
		}

		return
	}

	d.Res = p.genServerFailure(req)
	d.hasEDNS0 = false
}

// addDO ensures that msg carries an EDNS0 OPT record with the DO bit set.  If
// msg has no OPT record yet, one advertising defaultUDPBufSize is added.
func addDO(msg *dns.Msg) {
	opt := msg.IsEdns0()
	if opt == nil {
		msg.SetEdns0(defaultUDPBufSize, true)

		return
	}

	if !opt.Do() {
		opt.SetDo()
	}
}

const (
	// defaultUDPBufSize is the UDP payload size, in bytes, advertised in the
	// EDNS0 OPT records that the proxy adds itself.
	defaultUDPBufSize = 2048
)

// Resolve is the default resolving method used by the DNS proxy to query
// upstream servers.  It answers from the cache when possible.  Otherwise, it
// queries the upstreams, caches a suitable response, filters it, and passes
// the result to the ResponseHandler, if one is set.
func (p *Proxy) Resolve(dctx *DNSContext) (err error) {
	if p.EnableEDNSClientSubnet {
		dctx.processECS(p.EDNSAddr)
	}

	dctx.calcFlagsAndSize()

	// Use cache only if it's enabled and the query doesn't use custom upstream.
	// Also don't lookup the cache for responses with DNSSEC checking disabled
	// since only validated responses are cached and those may be not the
	// desired result for user specifying CD flag.
	cacheWorks := p.cacheWorks(dctx)
	if cacheWorks && p.replyFromCache(dctx) {
		// The response has been taken from the cache; just complete it.
		dctx.scrub()

		return nil
	}

	if cacheWorks {
		// On a cache miss, request DNSSEC data from the upstream so that it can
		// be cached afterwards.
		addDO(dctx.Req)
	}

	ok, err := p.replyFromUpstream(dctx)

	// Don't cache the responses having CD flag, just like Dnsmasq does.  It
	// prevents the cache from being poisoned with unvalidated answers which may
	// differ from validated ones.
	//
	// See https://github.com/imp/dnsmasq/blob/770bce967cfc9967273d0acfb3ea018fb7b17522/src/forward.c#L1169-L1172.
	if shouldCache := cacheWorks && ok && !dctx.Res.CheckingDisabled; shouldCache {
		// Cache the response along with its DNSSEC RRs.
		p.cacheResp(dctx)
	}

	// The response may be nil if no upstream has been chosen.
	if res := dctx.Res; res != nil {
		filterMsg(res, res, dctx.adBit, dctx.doBit, 0)
	}

	// Complete the response.
	dctx.scrub()

	if handler := p.ResponseHandler; handler != nil {
		handler(dctx, err)
	}

	return err
}

// cacheWorks reports whether the cache may be used for dctx.  It can't be used
// when the cache is disabled, when dctx has custom upstreams, or when the
// request has the CD flag set.  In those cases the reason is logged at the
// debug level.
func (p *Proxy) cacheWorks(dctx *DNSContext) (ok bool) {
	reason := ""
	switch {
	case p.cache == nil:
		reason = "disabled"
	case dctx.CustomUpstreamConfig != nil:
		// See https://github.com/AdguardTeam/dnsproxy/issues/169.
		reason = "custom upstreams used"
	case dctx.Req.CheckingDisabled:
		reason = "dnssec check disabled"
	}

	if reason == "" {
		return true
	}

	log.Debug("dnsproxy: cache: %s; not caching", reason)

	return false
}

// processECS fills dctx.ReqECS.  An ECS option already present in the request
// with a non-zero prefix length is passed through unchanged.  Otherwise, an ECS
// option is set from cliIP, or from the client's address if cliIP is nil,
// unless that address is a special-purpose one.
func (dctx *DNSContext) processECS(cliIP net.IP) {
	if ecs, _ := ecsFromMsg(dctx.Req); ecs != nil {
		ones, _ := ecs.Mask.Size()
		if ones != 0 {
			dctx.ReqECS = ecs

			log.Debug("dnsproxy: passing through ecs: %s", dctx.ReqECS)

			return
		}
	}

	ip := cliIP
	if ip == nil {
		if ip, _ = netutil.IPAndPortFromAddr(dctx.Addr); ip == nil {
			return
		}
	}

	if netutil.IsSpecialPurpose(ip) {
		return
	}

	// A Stub Resolver MUST set SCOPE PREFIX-LENGTH to 0.  See RFC 7871
	// Section 6.
	dctx.ReqECS = setECS(dctx.Req, ip, 0)

	log.Debug("dnsproxy: setting ecs: %s", dctx.ReqECS)
}

// newDNSContext creates a DNSContext for req received over proto.  The context
// is stamped with the current time and a request ID taken from p's atomic
// message counter.
func (p *Proxy) newDNSContext(proto Proto, req *dns.Msg) (d *DNSContext) {
	d = &DNSContext{
		Proto:     proto,
		Req:       req,
		StartTime: time.Now(),
	}
	d.RequestID = atomic.AddUint64(&p.counter, 1)

	return d
}
07070100000054000081A4000000000000000000000001650C592100008ADC000000000000000000000000000000000000002400000000dnsproxy-0.55.0/proxy/proxy_test.gopackage proxy

import (
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"io"
	"math/big"
	"net"
	"net/netip"
	"net/url"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	glcache "github.com/AdguardTeam/golibs/cache"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestMain(m *testing.M) {
	// Keep the test output clean by discarding everything written to the
	// global logger.
	log.SetOutput(io.Discard)

	exitCode := m.Run()
	os.Exit(exitCode)
}

// Constants shared by the tests in this package.
const (
	listenIP          = "127.0.0.1"
	tlsServerName     = "testdns.adguard.com"
	upstreamAddr      = "8.8.8.8:53"
	testMessagesCount = 10
)

// TestProxyRace sends multiple parallel DNS requests to the fully configured
// dnsproxy to check for race conditions.
func TestProxyRace(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)

	// Use the same upstream twice so that the proxy has to rotate them.
	ups := dnsProxy.UpstreamConfig.Upstreams
	dnsProxy.UpstreamConfig.Upstreams = append(ups, ups[0])

	err := dnsProxy.Start()
	require.NoError(t, err)
	// Stop the proxy even if the test fails halfway through.
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	// Create a DNS-over-UDP client connection and make sure it's closed; the
	// cleanups run in reverse order, so it's closed before the proxy stops.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	sendTestMessagesAsync(t, conn)
}

const (
	// defaultTestTTL is a TTL, in seconds, that is large enough to guarantee
	// that test responses stay cached during a test.
	defaultTestTTL = 1000
)

// testDNSSECUpstream is an upstream.Upstream stub.  It answers A, TXT, and DS
// queries with the corresponding preconfigured records, and it appends rrsig
// to the answer section of every request that has an EDNS0 OPT record.
type testDNSSECUpstream struct {
	// a is the answer to A queries.
	a dns.RR
	// txt is the answer to TXT queries.
	txt dns.RR
	// ds is the answer to DS queries.
	ds dns.RR
	// rrsig is appended to the answers of requests that have EDNS0.
	rrsig dns.RR
}

// Make sure *testDNSSECUpstream satisfies upstream.Upstream at compile time.
var _ upstream.Upstream = (*testDNSSECUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *testDNSSECUpstream.
// The TTL of the first answer record is set to defaultTestTTL, which mutates
// the shared record.
func (u *testDNSSECUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	resp = (&dns.Msg{}).SetReply(m)

	// Other query types, including explicit RRSIG requests, get no
	// type-specific answer.  The RRSIG record is only added below.
	answers := map[uint16]dns.RR{
		dns.TypeA:   u.a,
		dns.TypeTXT: u.txt,
		dns.TypeDS:  u.ds,
	}
	if rr, found := answers[m.Question[0].Qtype]; found {
		resp.Answer = append(resp.Answer, rr)
	}

	if len(resp.Answer) != 0 {
		resp.Answer[0].Header().Ttl = defaultTestTTL
	}

	if opt := m.IsEdns0(); opt != nil {
		resp.Answer = append(resp.Answer, u.rrsig)

		resp.SetEdns0(defaultUDPBufSize, opt.Do())
	}

	return resp, nil
}

// Address implements the upstream.Upstream interface for *testDNSSECUpstream.
// The stub has no real address, so the result is always empty.
func (u *testDNSSECUpstream) Address() (addr string) {
	return addr
}

// Close implements the upstream.Upstream interface for *testDNSSECUpstream.
// There is nothing to release, so it never fails.
func (u *testDNSSECUpstream) Close() (err error) {
	return err
}

func TestProxy_Resolve_dnssecCache(t *testing.T) {
	const host = "example.com"

	const (
		// txtDataLen is larger than the UDP buffer size advertised by the
		// EDNS0 requests below to invoke truncation.
		txtDataLen = 1024
		// txtDataChunkLen is the maximum length of a single character-string
		// within a TXT record.
		txtDataChunkLen = 255
	)

	txtDataChunkNum := txtDataLen / txtDataChunkLen
	if txtDataLen%txtDataChunkLen > 0 {
		txtDataChunkNum++
	}

	txts := make([]string, txtDataChunkNum)
	randData := make([]byte, txtDataLen)
	n, err := rand.Read(randData)
	require.NoError(t, err)
	require.Equal(t, txtDataLen, n)

	// Map the random bytes onto lowercase ASCII letters.
	for i, c := range randData {
		randData[i] = c%26 + 'a'
	}

	// The data of a *dns.TXT is a list of character-strings, each at most 255
	// bytes long, so split the data into chunks of that size.
	for i := 0; i < txtDataChunkNum; i++ {
		r := txtDataChunkLen * (i + 1)
		if r > txtDataLen {
			r = txtDataLen
		}
		txts[i] = string(randData[txtDataChunkLen*i : r])
	}

	txt := &dns.TXT{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeTXT,
			Class:  dns.ClassINET,
		},
		Txt: txts,
	}

	a := &dns.A{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
		},
		A: net.IP{1, 2, 3, 4},
	}

	ds := &dns.DS{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeDS,
			Class:  dns.ClassINET,
		},
		Digest: "736f6d652064656c65676174696f6e207369676e6572",
	}

	rrsig := &dns.RRSIG{
		Hdr: dns.RR_Header{
			Name:   dns.Fqdn(host),
			Rrtype: dns.TypeRRSIG,
			Class:  dns.ClassINET,
			Ttl:    defaultTestTTL,
		},
		TypeCovered: dns.TypeA,
		Algorithm:   8,
		Labels:      2,
		SignerName:  dns.Fqdn(host),
		Signature:   "c29tZSBycnNpZyByZWxhdGVkIHN0dWZm",
	}

	p := &Proxy{}
	p.UpstreamConfig = &UpstreamConfig{
		Upstreams: []upstream.Upstream{&testDNSSECUpstream{
			a:     a,
			txt:   txt,
			ds:    ds,
			rrsig: rrsig,
		}},
	}
	p.cache = newCache(defaultCacheSize, false, false)

	testCases := []struct {
		wantAns dns.RR
		name    string
		wantLen int
		edns    bool
	}{{
		wantAns: a,
		name:    "a_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: a,
		name:    "a_edns",
		wantLen: 2,
		edns:    true,
	}, {
		wantAns: txt,
		name:    "txt_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: txt,
		name:    "txt_edns",
		// Truncated.
		wantLen: 0,
		edns:    true,
	}, {
		wantAns: ds,
		name:    "ds_noedns",
		wantLen: 1,
		edns:    false,
	}, {
		wantAns: ds,
		name:    "ds_edns",
		wantLen: 2,
		edns:    true,
	}}

	for _, tc := range testCases {
		req := &dns.Msg{
			MsgHdr: dns.MsgHdr{
				Id: dns.Id(),
			},
			Compress: true,
			Question: []dns.Question{{
				Name:   dns.Fqdn(tc.wantAns.Header().Name),
				Qtype:  tc.wantAns.Header().Rrtype,
				Qclass: tc.wantAns.Header().Class,
			}},
		}
		if tc.edns {
			// Advertise a buffer smaller than the TXT data.
			req.SetEdns0(txtDataLen/2, true)
		}

		dctx := &DNSContext{
			Req:   req,
			Proto: ProtoUDP,
		}

		t.Run(tc.name, func(t *testing.T) {
			t.Cleanup(p.cache.items.Clear)

			resolveErr := p.Resolve(dctx)
			require.NoError(t, resolveErr)

			res := dctx.Res
			require.NotNil(t, res)

			require.Len(t, res.Answer, tc.wantLen, res.Answer)
			switch tc.wantLen {
			case 0:
				assert.True(t, res.Truncated)
			case 1:
				res.Answer[0].Header().Ttl = defaultTestTTL
				assert.Equal(t, tc.wantAns.String(), res.Answer[0].String())
			case 2:
				res.Answer[0].Header().Ttl = defaultTestTTL
				assert.Equal(t, tc.wantAns.String(), res.Answer[0].String())
				assert.Equal(t, rrsig.String(), res.Answer[1].String())
			default:
				t.Fatalf("wanted length has unexpected value %d", tc.wantLen)
			}

			// The cached response must always contain the DNSSEC data.
			cached, expired, key := p.cache.get(dctx.Req)
			require.NotNil(t, cached)
			require.Len(t, cached.m.Answer, 2)
			assert.False(t, expired)
			assert.Equal(t, key, msgToKey(dctx.Req))

			// Just make it match.
			cached.m.Answer[0].Header().Ttl = defaultTestTTL
			assert.Equal(t, tc.wantAns.String(), cached.m.Answer[0].String())
			assert.Equal(t, rrsig.String(), cached.m.Answer[1].String())
		})
	}
}

func TestUpstreamsSort(t *testing.T) {
	testProxy := createTestProxy(t, nil)

	// There are 4 upstreams in the configuration.
	addrs := []string{"1.2.3.4", "1.1.1.1", "2.3.4.5", "8.8.8.8"}
	upstreams := make([]upstream.Upstream, 0, len(addrs))
	for _, addr := range addrs {
		u, err := upstream.AddressToUpstream(addr, &upstream.Options{Timeout: 1 * time.Second})
		require.NoErrorf(t, err, "creating %s upstream", addr)
		testutil.CleanupAndRequireSuccess(t, u.Close)

		upstreams = append(upstreams, u)
	}

	// Only 3 of the upstreams have RTT statistics.
	testProxy.upstreamRttStats = map[string]int{
		"1.1.1.1:53": 10,
		"2.3.4.5:53": 20,
		"1.2.3.4:53": 30,
	}

	sortedUpstreams := testProxy.getSortedUpstreams(upstreams)

	got := make([]string, 0, len(sortedUpstreams))
	for _, u := range sortedUpstreams {
		got = append(got, u.Address())
	}

	// An upstream without RTT statistics is considered to have zero RTT, so it
	// must be the first one.  The rest must be sorted from fast to slow.
	want := []string{"8.8.8.8:53", "1.1.1.1:53", "2.3.4.5:53", "1.2.3.4:53"}
	assert.Equal(t, want, got)
}

func TestExchangeWithReservedDomains(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)

	// Upstreams specification.  Domains adguard.com and google.ru are reserved
	// for fake upstreams, and maps.google.ru is excluded from the reservation
	// using the dnsmasq-like "#" syntax.
	upstreams := []string{
		"[/adguard.com/]1.2.3.4",
		"[/google.ru/]2.3.4.5",
		"[/maps.google.ru/]#",
		"1.1.1.1",
	}
	config, err := ParseUpstreamsConfig(
		upstreams,
		&upstream.Options{
			InsecureSkipVerify: false,
			Bootstrap:          []string{"8.8.8.8"},
			Timeout:            1 * time.Second,
		},
	)
	require.NoError(t, err)

	dnsProxy.UpstreamConfig = config

	err = dnsProxy.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	// Create a DNS-over-TCP client connection.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Make sure that dnsproxy is working using a google-a test message.
	req := createTestMessage()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	// exchange sends a test request for host over conn and returns the
	// response, failing the test on any I/O error.
	exchange := func(host string) (resp *dns.Msg) {
		t.Helper()

		hostReq := createHostTestMessage(host)
		require.NoError(t, conn.WriteMsg(hostReq))

		resp, readErr := conn.ReadMsg()
		require.NoError(t, readErr)
		require.NotNil(t, resp)

		return resp
	}

	// Requests for reserved domains must not be resolved.
	require.Nil(t, exchange("adguard.com").Answer)
	require.Nil(t, exchange("www.google.ru").Answer)

	// Requests for the excluded domain must be resolved.
	require.NotNil(t, exchange("maps.google.ru").Answer)
}

// TestOneByOneUpstreamsExchange tries to resolve a DNS request using one valid
// and two invalid upstreams.
func TestOneByOneUpstreamsExchange(t *testing.T) {
	timeOut := 1 * time.Second
	dnsProxy := createTestProxy(t, nil)

	// Use an invalid fallback to make sure that the reply doesn't come from
	// the fallback server.
	var err error
	dnsProxy.Fallbacks, err = ParseUpstreamsConfig(
		[]string{"1.2.3.4:567"},
		&upstream.Options{Timeout: timeOut},
	)
	require.NoError(t, err)

	// Add two invalid upstreams and one valid one.
	upstreams := []string{"https://fake-dns.com/fake-dns-query", "tls://fake-dns.com", "1.1.1.1"}
	dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{}
	for _, line := range upstreams {
		var u upstream.Upstream
		u, err = upstream.AddressToUpstream(
			line,
			&upstream.Options{
				Bootstrap: []string{"8.8.8.8:53"},
				Timeout:   timeOut,
			},
		)
		require.NoError(t, err)

		dnsProxy.UpstreamConfig.Upstreams = append(dnsProxy.UpstreamConfig.Upstreams, u)
	}

	err = dnsProxy.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	// Create a DNS-over-TCP client connection.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Make sure that the response is okay and resolved by the valid upstream.
	req := createTestMessage()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	start := time.Now()
	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	elapsed := time.Since(start)
	if elapsed > 3*timeOut {
		t.Fatalf("the operation took much more time than the configured timeout")
	}
}

// newLocalUpstreamListener starts a plain DNS-over-TCP server on the IPv4
// localhost address with the specified port and serves requests with h.  Pass
// 0 as port to let the system pick a free one.  The server is shut down on test
// cleanup.  It returns the actual listening address.
func newLocalUpstreamListener(t *testing.T, port uint16, h dns.Handler) (addr netip.AddrPort) {
	t.Helper()

	startCh := make(chan struct{})
	upsSrv := &dns.Server{
		Addr:              netip.AddrPortFrom(netutil.IPv4Localhost(), port).String(),
		Net:               "tcp",
		Handler:           h,
		NotifyStartedFunc: func() { close(startCh) },
	}
	go func() {
		err := upsSrv.ListenAndServe()
		require.NoError(testutil.PanicT{}, err)
	}()

	// Wait until the server is actually listening before reading its address.
	<-startCh
	testutil.CleanupAndRequireSuccess(t, upsSrv.Shutdown)

	return testutil.RequireTypeAssert[*net.TCPAddr](t, upsSrv.Listener.Addr()).AddrPort()
}

// TestFallback checks which main and fallback upstreams are queried, and in
// which order, for general and domain-specific requests.  Every local upstream
// reports the ID of each request it receives: the ones using successHandler
// into responseCh, answering with a proper reply, and the one using
// failHandler into failCh, answering with an empty message.
func TestFallback(t *testing.T) {
	responseCh := make(chan uint16)
	failCh := make(chan uint16)

	const timeout = 1 * time.Second

	// successHandler reports the request and replies to it properly.  It uses
	// PanicT, since it runs outside of the test's goroutine.
	successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		testutil.RequireSend(testutil.PanicT{}, responseCh, r.Id, timeout)

		require.NoError(testutil.PanicT{}, w.WriteMsg((&dns.Msg{}).SetReply(r)))
	})
	// failHandler reports the request and answers with an empty message that
	// isn't a reply to it, which the proxy is expected to treat as a failure.
	failHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		testutil.RequireSend(testutil.PanicT{}, failCh, r.Id, timeout)

		require.NoError(testutil.PanicT{}, w.WriteMsg(&dns.Msg{}))
	})

	// Start three local upstreams; the two successful ones share a handler.
	successAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, successHandler).String(),
	}).String()
	alsoSuccessAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, successHandler).String(),
	}).String()
	failAddr := (&url.URL{
		Scheme: string(ProtoTCP),
		Host:   newLocalUpstreamListener(t, 0, failHandler).String(),
	}).String()

	dnsProxy := createTestProxy(t, nil)

	var err error
	dnsProxy.UpstreamConfig, err = ParseUpstreamsConfig(
		[]string{
			failAddr,
			"[/specific.example/]" + alsoSuccessAddr,
			// almost.failing.example will fall here first.
			"[/failing.example/]" + failAddr,
		},
		&upstream.Options{Timeout: timeout},
	)
	require.NoError(t, err)

	dnsProxy.Fallbacks, err = ParseUpstreamsConfig(
		[]string{
			failAddr,
			successAddr,
			"[/failing.example/]" + failAddr,
			"[/almost.failing.example/]" + alsoSuccessAddr,
		},
		&upstream.Options{Timeout: timeout},
	)
	require.NoError(t, err)

	err = dnsProxy.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	conn, err := dns.Dial("tcp", dnsProxy.Addr(ProtoTCP).String())
	require.NoError(t, err)

	// wantSignals is the sequence of channels expected to receive the request
	// ID, one entry per upstream queried, in the order of querying.
	testCases := []struct {
		name        string
		wantSignals []chan uint16
	}{{
		name: "general.example",
		wantSignals: []chan uint16{
			// The general main upstream fails first.
			failCh,
			// Both non-specific fallbacks tried.
			failCh,
			responseCh,
		},
	}, {
		name: "specific.example",
		wantSignals: []chan uint16{
			// The domain-specific main upstream succeeds right away.
			responseCh,
		},
	}, {
		name: "failing.example",
		wantSignals: []chan uint16{
			// Both the main and the fallback upstreams for this domain fail.
			failCh,
			failCh,
		},
	}, {
		name: "almost.failing.example",
		wantSignals: []chan uint16{
			// The main upstream for the parent domain fails, then the more
			// specific fallback succeeds.
			failCh,
			responseCh,
		},
	}}

	// The subtests share conn and run sequentially.
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := createHostTestMessage(tc.name)
			err = conn.WriteMsg(req)
			require.NoError(t, err)

			for _, ch := range tc.wantSignals {
				reqID, ok := testutil.RequireReceive(testutil.PanicT{}, ch, timeout)
				require.True(t, ok)

				assert.Equal(t, req.Id, reqID)
			}

			// Read the final response from the proxy before the next
			// subtest writes to the same connection.
			_, err = conn.ReadMsg()
			require.NoError(t, err)
		})
	}
}

// TestFallbackFromInvalidBootstrap checks that the proxy answers using the
// fallback upstreams when its only main upstream uses an unusable bootstrap
// resolver, and that it does so within three times the configured timeout.
func TestFallbackFromInvalidBootstrap(t *testing.T) {
	const timeout = 1 * time.Second

	dnsProxy := createTestProxy(t, nil)

	// Both fallback server addresses are valid.
	var err error
	dnsProxy.Fallbacks, err = ParseUpstreamsConfig(
		[]string{"1.0.0.1", "8.8.8.8"},
		&upstream.Options{Timeout: timeout},
	)
	require.NoError(t, err)

	// Use a DoT server with a bootstrap resolver on a port where no DNS
	// server is expected to listen.  Check the error, since a nil upstream
	// would only panic later, far from the cause.
	u, err := upstream.AddressToUpstream(
		"tls://dns.adguard.com",
		&upstream.Options{
			Bootstrap: []string{"8.8.8.8:555"},
			Timeout:   timeout,
		},
	)
	require.NoError(t, err)

	dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{u}

	err = dnsProxy.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	// Create a DNS-over-UDP client connection and close it when done.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, conn.Close)

	// Make sure that the response is okay and resolved by the fallback.
	req := createTestMessage()
	err = conn.WriteMsg(req)
	require.NoError(t, err)

	start := time.Now()
	res, err := conn.ReadMsg()
	require.NoError(t, err)
	requireResponse(t, req, res)

	elapsed := time.Since(start)
	if elapsed > 3*timeout {
		t.Fatalf("the operation took much more time than the configured timeout")
	}
}

// TestRefuseAny checks that, with RefuseAny enabled, the proxy answers ANY
// queries with NOTIMP.
func TestRefuseAny(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)
	dnsProxy.RefuseAny = true

	// Start listening and make sure the proxy is stopped even if the test
	// fails before reaching its end.
	err := dnsProxy.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}

	request := &dns.Msg{}
	request.Id = dns.Id()
	request.RecursionDesired = true
	request.SetQuestion("google.com.", dns.TypeANY)

	r, _, err := client.Exchange(request, addr.String())
	require.NoError(t, err)

	assert.Equal(t, dns.RcodeNotImplemented, r.Rcode)
}

// TestInvalidDNSRequest checks that the proxy answers a request without a
// question section with SERVFAIL.
func TestInvalidDNSRequest(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)
	dnsProxy.RefuseAny = true

	// Start listening and make sure the proxy is stopped even if the test
	// fails before reaching its end.
	err := dnsProxy.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}

	// Create a DNS request without any questions.
	request := &dns.Msg{}
	request.Id = dns.Id()
	request.RecursionDesired = true

	r, _, err := client.Exchange(request, addr.String())
	require.NoError(t, err)

	assert.Equal(t, dns.RcodeServerFailure, r.Rcode)
}

// TestResponseInRequest checks that the server drops incoming messages with
// the Response flag set, so that the client gets no answer.
func TestResponseInRequest(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)
	require.NoError(t, dnsProxy.Start())
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}

	req := createTestMessage()
	req.Response = true

	r, _, err := client.Exchange(req, addr.String())
	assert.Error(t, err)
	assert.Nil(t, r)
}

// TestNoQuestion checks that the server answers a request with an empty
// question section with SERVFAIL.
func TestNoQuestion(t *testing.T) {
	prx := createTestProxy(t, nil)
	require.NoError(t, prx.Start())
	testutil.CleanupAndRequireSuccess(t, prx.Stop)

	msg := createTestMessage()
	msg.Question = nil

	udpClient := &dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}
	resp, _, err := udpClient.Exchange(msg, prx.Addr(ProtoUDP).String())
	require.NoError(t, err)

	assert.Equal(t, dns.RcodeServerFailure, resp.Rcode)
}

// funcUpstream is a mock upstream.Upstream whose Exchange and Address
// behavior is supplied by the test through function fields.  Unset fields
// fall back to trivial defaults.
type funcUpstream struct {
	// exchangeFunc, if set, handles Exchange calls.
	exchangeFunc func(m *dns.Msg) (resp *dns.Msg, err error)
	// addressFunc, if set, handles Address calls.
	addressFunc func() (addr string)
}

// type check
var _ upstream.Upstream = (*funcUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *funcUpstream.  It
// returns a nil response and a nil error when exchangeFunc is not set.
func (fu *funcUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) {
	if f := fu.exchangeFunc; f != nil {
		return f(m)
	}

	return nil, nil
}

// Address implements the upstream.Upstream interface for *funcUpstream.  It
// returns "stub" when addressFunc is not set.
func (fu *funcUpstream) Address() (addr string) {
	addr = "stub"
	if fu.addressFunc != nil {
		addr = fu.addressFunc()
	}

	return addr
}

// Close implements the upstream.Upstream interface for *funcUpstream.  There
// is nothing to release, so it always returns nil.
func (fu *funcUpstream) Close() (err error) {
	return err
}

// TestProxy_ReplyFromUpstream_badResponse checks that resolving through an
// upstream that returns a response with an empty question section neither
// panics nor fails, and that the final response has the request's question.
func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)
	require.NoError(t, dnsProxy.Start())
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	u := &funcUpstream{
		exchangeFunc: func(m *dns.Msg) (resp *dns.Msg, err error) {
			resp = (&dns.Msg{}).SetReply(m)
			resp.Answer = append(resp.Answer, &dns.A{
				Hdr: dns.RR_Header{
					Name:   m.Question[0].Name,
					Class:  dns.ClassINET,
					Rrtype: dns.TypeA,
				},
				A: net.IP{1, 2, 3, 4},
			})

			// Drop the question section to make the response invalid.
			resp.Question = []dns.Question{}

			return resp, nil
		},
	}

	dctx := &DNSContext{
		CustomUpstreamConfig: &UpstreamConfig{Upstreams: []upstream.Upstream{u}},
		Req:                  createHostTestMessage("host"),
		Addr:                 &net.TCPAddr{IP: net.IP{1, 2, 3, 0}},
	}

	var err error
	require.NotPanics(t, func() {
		err = dnsProxy.Resolve(dctx)
	})
	require.NoError(t, err)

	assert.Equal(t, dctx.Req.Question[0], dctx.Res.Question[0])
}

// TestExchangeCustomUpstreamConfig checks that the answer comes from the
// upstreams in the context's CustomUpstreamConfig.
func TestExchangeCustomUpstreamConfig(t *testing.T) {
	prx := createTestProxy(t, nil)
	require.NoError(t, prx.Start())
	testutil.CleanupAndRequireSuccess(t, prx.Stop)

	wantIP := net.IP{4, 3, 2, 1}
	ups := &testUpstream{
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   wantIP,
		}},
	}

	dctx := &DNSContext{
		CustomUpstreamConfig: &UpstreamConfig{Upstreams: []upstream.Upstream{ups}},
		Req:                  createHostTestMessage("host"),
		Addr:                 &net.TCPAddr{IP: net.IP{1, 2, 3, 0}},
	}

	require.NoError(t, prx.Resolve(dctx))

	assert.Equal(t, wantIP, getIPFromResponse(dctx.Res))
}

// TestECS checks that setECS puts an EDNS Client Subnet option into a
// message and that ecsFromMsg reads back the same masked subnet and scope.
func TestECS(t *testing.T) {
	t.Run("ipv4", func(t *testing.T) {
		ip := net.IP{1, 2, 3, 4}
		msg := &dns.Msg{}

		setOnes, _ := setECS(msg, ip, 16).Mask.Size()
		assert.Equal(t, 24, setOnes)

		gotSubnet, gotScope := ecsFromMsg(msg)
		assert.Equal(t, ip.Mask(gotSubnet.Mask), gotSubnet.IP)

		gotOnes, _ := gotSubnet.Mask.Size()
		assert.Equal(t, 24, gotOnes)
		assert.Equal(t, 16, gotScope)
	})

	t.Run("ipv6", func(t *testing.T) {
		ip := net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
		msg := &dns.Msg{}

		setOnes, _ := setECS(msg, ip, 48).Mask.Size()
		assert.Equal(t, 56, setOnes)

		gotSubnet, gotScope := ecsFromMsg(msg)
		assert.Equal(t, ip.Mask(gotSubnet.Mask), gotSubnet.IP)

		gotOnes, _ := gotSubnet.Mask.Size()
		assert.Equal(t, 56, gotOnes)
		assert.Equal(t, 48, gotScope)
	})
}

// TestECSProxy resolves the same host for clients from different networks
// with ECS and caching enabled, and checks whether each answer comes from the
// subnet cache, the general cache, or the upstream.  The subtests depend on
// the cache state left by the previous ones and on the shared mock upstream u,
// so they must run in order.
func TestECSProxy(t *testing.T) {
	prx := createTestProxy(t, nil)
	prx.EnableEDNSClientSubnet = true
	prx.CacheEnabled = true

	var (
		ip1230 = net.IP{1, 2, 3, 0}
		ip2230 = net.IP{2, 2, 3, 0}
		ip4321 = net.IP{4, 3, 2, 1}
		ip4322 = net.IP{4, 3, 2, 2}
		ip4323 = net.IP{4, 3, 2, 3}
	)

	// u initially answers with ip4321 and puts an ECS option for ip1230 into
	// its responses.
	u := testUpstream{
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4321,
		}},
		ecsIP: ip1230,
	}
	prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u}
	err := prx.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, prx.Stop)

	// The first request from 1.2.3.0 reaches the upstream with the client's
	// subnet, and the response is stored in the subnet cache.
	t.Run("cache_subnet", func(t *testing.T) {
		d := DNSContext{
			Req:  createHostTestMessage("host"),
			Addr: &net.TCPAddr{IP: ip1230},
		}

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, net.IP{4, 3, 2, 1}, getIPFromResponse(d.Res))
		assert.Equal(t, ip1230, u.ecsReqIP)
	})

	// A request from 1.2.3.1, in the same /24, must be answered from the
	// subnet cache: the upstream's answer is cleared, and it must not see an
	// ECS request.
	t.Run("serve_subnet_cache", func(t *testing.T) {
		d := DNSContext{
			Req:  createHostTestMessage("host"),
			Addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 1}},
		}
		u.ans, u.ecsIP, u.ecsReqIP = nil, nil, nil

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, ip4321, getIPFromResponse(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})

	// A request from another subnet misses the subnet cache and reaches the
	// upstream, which now answers with ip4322.
	t.Run("another_subnet", func(t *testing.T) {
		d := DNSContext{
			Req:  createHostTestMessage("host"),
			Addr: &net.TCPAddr{IP: ip2230},
		}
		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4322,
		}}
		u.ecsIP = ip2230

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, ip4322, getIPFromResponse(d.Res))
		assert.Equal(t, ip2230, u.ecsReqIP)
	})

	// A request from 127.0.0.1 is expected to reach the upstream without an
	// ECS option (u.ecsReqIP stays nil), so its response goes to the general
	// cache.
	t.Run("cache_general", func(t *testing.T) {
		d := DNSContext{
			Req:  createHostTestMessage("host"),
			Addr: &net.TCPAddr{IP: net.IP{127, 0, 0, 1}},
		}
		u.ans = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 60},
			A:   ip4323,
		}}
		u.ecsIP, u.ecsReqIP = nil, nil

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, ip4323, getIPFromResponse(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})

	// A request from 127.0.0.2 must be answered from the general cache, since
	// the upstream's answer is cleared.
	t.Run("serve_general_cache", func(t *testing.T) {
		d := DNSContext{
			Req:  createHostTestMessage("host"),
			Addr: &net.TCPAddr{IP: net.IP{127, 0, 0, 2}},
		}
		u.ans, u.ecsIP, u.ecsReqIP = nil, nil, nil

		err = prx.Resolve(&d)
		require.NoError(t, err)

		assert.Equal(t, ip4323, getIPFromResponse(d.Res))
		assert.Nil(t, u.ecsReqIP)
	})
}

// TestECSProxyCacheMinMaxTTL checks that, with ECS enabled, the TTLs of
// responses stored in the subnet cache are raised to CacheMinTTL and lowered
// to CacheMaxTTL.
func TestECSProxyCacheMinMaxTTL(t *testing.T) {
	clientIP := net.IP{1, 2, 3, 0}

	prx := createTestProxy(t, nil)
	prx.EnableEDNSClientSubnet = true
	prx.CacheEnabled = true
	prx.CacheMinTTL = 20
	prx.CacheMaxTTL = 40
	u := testUpstream{
		ans: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Rrtype: dns.TypeA,
				Name:   "host.",
				Ttl:    10,
			},
			A: net.IP{4, 3, 2, 1},
		}},
		ecsIP: clientIP,
	}
	prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u}
	err := prx.Start()
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, prx.Stop)

	// First request: the upstream's TTL of 10 is below CacheMinTTL.
	d := DNSContext{
		Req:  createHostTestMessage("host"),
		Addr: &net.TCPAddr{IP: clientIP},
	}
	err = prx.Resolve(&d)
	require.NoError(t, err)

	// Get from the cache and check the min TTL.  Fail the test instead of
	// panicking if nothing was cached.
	ci, expired, key := prx.cache.getWithSubnet(d.Req, &net.IPNet{
		IP:   clientIP,
		Mask: net.CIDRMask(24, netutil.IPv4BitLen),
	})
	require.NotNil(t, ci)
	require.NotEmpty(t, ci.m.Answer)

	assert.False(t, expired)
	assert.Equal(t, msgToKeyWithSubnet(d.Req, clientIP, 24), key)
	assert.Equal(t, prx.CacheMinTTL, ci.m.Answer[0].Header().Ttl)

	// Second request, from another subnet: the upstream's TTL of 60 is above
	// CacheMaxTTL.
	clientIP = net.IP{1, 2, 4, 0}
	d.Req = createHostTestMessage("host")
	d.Addr = &net.TCPAddr{
		IP: clientIP,
	}
	u.ans = []dns.RR{&dns.A{
		Hdr: dns.RR_Header{
			Rrtype: dns.TypeA,
			Name:   "host.",
			Ttl:    60,
		},
		A: net.IP{4, 3, 2, 1},
	}}
	u.ecsIP = clientIP
	err = prx.Resolve(&d)
	require.NoError(t, err)

	// Get from the cache and check the max TTL.
	ci, expired, key = prx.cache.getWithSubnet(d.Req, &net.IPNet{
		IP:   clientIP,
		Mask: net.CIDRMask(24, netutil.IPv4BitLen),
	})
	require.NotNil(t, ci)
	require.NotEmpty(t, ci.m.Answer)

	assert.False(t, expired)
	assert.Equal(t, msgToKeyWithSubnet(d.Req, clientIP, 24), key)
	assert.Equal(t, prx.CacheMaxTTL, ci.m.Answer[0].Header().Ttl)
}

// createTestDNSCryptProxy returns an unstarted proxy that listens only for
// DNSCrypt over UDP and TCP on the same free port on listenIP.  It also
// returns the resolver configuration used to create the proxy's certificate.
// Configuration errors fail the test immediately.
func createTestDNSCryptProxy(t *testing.T) (*Proxy, dnscrypt.ResolverConfig) {
	t.Helper()

	p := createTestProxy(t, nil)
	p.UDPListenAddr = nil
	p.TCPListenAddr = nil

	ip := net.ParseIP(listenIP)
	port := int(getFreePort())
	p.DNSCryptUDPListenAddr = []*net.UDPAddr{{IP: ip, Port: port}}
	p.DNSCryptTCPListenAddr = []*net.TCPAddr{{IP: ip, Port: port}}

	// Use require, since continuing with a broken resolver configuration
	// would only produce confusing failures later.
	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	cert, err := rc.CreateCert()
	require.NoError(t, err)

	p.DNSCryptProviderName = rc.ProviderName
	p.DNSCryptResolverCert = cert

	return p, rc
}

// getFreePort returns a TCP port on 127.0.0.1 that was free at the moment of
// the call.  The port is released before returning, so something else may
// still take it.  It panics with the underlying error if no port can be
// obtained, instead of failing later on a nil listener.
func getFreePort() uint {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}

	addr, ok := l.Addr().(*net.TCPAddr)

	// Stop listening immediately, only the port number is needed.
	_ = l.Close()

	if !ok {
		panic("getFreePort: unexpected listener address type")
	}

	// Sleep for 100ms, which may be necessary on Windows.
	time.Sleep(100 * time.Millisecond)

	return uint(addr.Port)
}

// createTestProxy returns a new unstarted proxy that uses upstreamAddr as its
// only upstream and trusts all addresses as proxies.  If tlsConfig is nil, the
// proxy listens for plain DNS over UDP and TCP; otherwise it listens for TLS,
// HTTPS, and QUIC with tlsConfig.  All listen addresses use listenIP and
// system-chosen ports.
func createTestProxy(t *testing.T, tlsConfig *tls.Config) *Proxy {
	t.Helper()

	p := &Proxy{}

	ip := net.ParseIP(listenIP)
	if tlsConfig == nil {
		p.UDPListenAddr = []*net.UDPAddr{{IP: ip, Port: 0}}
		p.TCPListenAddr = []*net.TCPAddr{{IP: ip, Port: 0}}
	} else {
		p.TLSListenAddr = []*net.TCPAddr{{IP: ip, Port: 0}}
		p.HTTPSListenAddr = []*net.TCPAddr{{IP: ip, Port: 0}}
		p.QUICListenAddr = []*net.UDPAddr{{IP: ip, Port: 0}}
		p.TLSConfig = tlsConfig
	}

	dnsUpstream, err := upstream.AddressToUpstream(
		upstreamAddr,
		&upstream.Options{Timeout: defaultTimeout},
	)
	require.NoError(t, err)

	p.UpstreamConfig = &UpstreamConfig{
		Upstreams: []upstream.Upstream{dnsUpstream},
	}
	p.TrustedProxies = []string{"0.0.0.0/0", "::0/0"}

	return p
}

// sendTestMessageAsync sends a single test message over conn, reads one
// response, and checks that it contains exactly one A record for 8.8.8.8.  It
// is meant to run in its own goroutine and calls g.Done when it returns.
//
// Failures are reported with assert and an early return rather than with
// require.  require calls t.FailNow, which must only be called from the
// goroutine running the test.
func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) {
	defer g.Done()

	req := createTestMessage()
	err := conn.WriteMsg(req)
	if !assert.NoError(t, err) {
		return
	}

	res, err := conn.ReadMsg()
	if !assert.NoError(t, err) {
		return
	}

	// We do not check if msg IDs match because the order of responses may
	// be different.

	if !assert.NotNil(t, res) ||
		!assert.Lenf(t, res.Answer, 1, "wrong number of answers: %d", len(res.Answer)) {
		return
	}

	a, ok := res.Answer[0].(*dns.A)
	if !assert.Truef(t, ok, "wrong answer type: %v", res.Answer[0]) {
		return
	}

	assert.Equalf(t, net.IPv4(8, 8, 8, 8), a.A.To16(), "wrong answer: %v", a.A)
}

// sendTestMessagesAsync sends testMessagesCount test messages over conn
// concurrently and waits until all of them have been handled.  Running them
// in parallel helps to find data races.
func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
	wg := &sync.WaitGroup{}
	for n := testMessagesCount; n > 0; n-- {
		wg.Add(1)
		go sendTestMessageAsync(t, conn, wg)
	}

	wg.Wait()
}

// sendTestMessages sequentially sends ten test messages over conn and
// requires a valid response to each of them.
func sendTestMessages(t *testing.T, conn *dns.Conn) {
	const count = 10

	for i := 0; i < count; i++ {
		req := createTestMessage()
		if err := conn.WriteMsg(req); err != nil {
			t.Fatalf("cannot write message #%d: %s", i, err)
		}

		res, err := conn.ReadMsg()
		if err != nil {
			t.Fatalf("cannot read response to message #%d: %s", i, err)
		}

		requireResponse(t, req, res)
	}
}

// createTestMessage returns a new recursive A query for
// google-public-dns-a.google.com.
func createTestMessage() *dns.Msg {
	const host = "google-public-dns-a.google.com"

	return createHostTestMessage(host)
}

// createHostTestMessage returns a new recursive IN A query for host with a
// random ID.  host is expected without the trailing dot, since one is
// appended here.
func createHostTestMessage(host string) *dns.Msg {
	return &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id:               dns.Id(),
			RecursionDesired: true,
		},
		Question: []dns.Question{{
			Name:   host + ".",
			Qtype:  dns.TypeA,
			Qclass: dns.ClassINET,
		}},
	}
}

// requireResponse requires reply to be non-nil, to have req's ID, and to
// contain exactly one answer, which is an A record for 8.8.8.8.
func requireResponse(t testing.TB, req, reply *dns.Msg) {
	t.Helper()

	require.NotNil(t, reply)

	answers := reply.Answer
	require.Lenf(t, answers, 1, "wrong number of answers: %d", len(answers))
	require.Equal(t, req.Id, reply.Id)

	first := answers[0]
	a, ok := first.(*dns.A)
	require.Truef(t, ok, "wrong answer type: %v", first)

	require.Equalf(t, net.IPv4(8, 8, 8, 8), a.A.To16(), "wrong answer: %v", a.A)
}

// createServerTLSConfig generates a self-signed 2048-bit RSA CA certificate
// for tlsServerName, valid for five years from now.  It returns a TLS
// configuration serving that certificate and the certificate in PEM
// encoding.
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte) {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("cannot generate RSA key: %s", err)
	}

	// The serial number is a random value below 2^128.
	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialLimit)
	if err != nil {
		t.Fatalf("failed to generate serial number: %s", err)
	}

	validFrom := time.Now()
	template := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{Organization: []string{"AdGuard Tests"}},
		DNSNames:     []string{tlsServerName},
		NotBefore:    validFrom,
		NotAfter:     validFrom.Add(5 * 365 * 24 * time.Hour),

		KeyUsage: x509.KeyUsageKeyEncipherment |
			x509.KeyUsageDigitalSignature |
			x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// The certificate is self-signed, so the template is its own parent.
	der, err := x509.CreateCertificate(rand.Reader, template, template, publicKey(key), key)
	if err != nil {
		t.Fatalf("failed to create certificate: %s", err)
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})

	pair, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		t.Fatalf("failed to create certificate: %s", err)
	}

	conf := &tls.Config{
		Certificates: []tls.Certificate{pair},
		ServerName:   tlsServerName,
	}

	return conf, certPEM
}

func publicKey(priv interface{}) interface{} {
	switch k := priv.(type) {
	case *rsa.PrivateKey:
		return &k.PublicKey
	case *ecdsa.PrivateKey:
		return &k.PublicKey
	default:
		return nil
	}
}

// getIPFromResponse returns the address of the first A record in the answer
// section of resp, or nil if there is none.
func getIPFromResponse(resp *dns.Msg) net.IP {
	for _, rr := range resp.Answer {
		if a, ok := rr.(*dns.A); ok {
			return a.A
		}
	}

	return nil
}

// testUpstream is a mock upstream.Upstream that answers every query with the
// configured records.  It also remembers the EDNS Client Subnet option of the
// last query that had one.
type testUpstream struct {
	// ans are the records appended to the answer section of every response.
	ans []dns.RR

	// ecsIP, if not nil, is put into an ECS option of every response using
	// setECS with 24 as the scope.
	ecsIP net.IP
	// ecsReqIP is the subnet IP from the ECS option of the last request that
	// had one.
	ecsReqIP net.IP
	// ecsReqMask is the prefix length from the ECS option of the last
	// request that had one.
	ecsReqMask int
}

// type check
var _ upstream.Upstream = (*testUpstream)(nil)

// Exchange implements the upstream.Upstream interface for *testUpstream.  It
// never returns an error.
func (u *testUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	resp = (&dns.Msg{}).SetReply(m)
	resp.Answer = append(resp.Answer, u.ans...)

	if reqECS, _ := ecsFromMsg(m); reqECS != nil {
		u.ecsReqIP = reqECS.IP
		u.ecsReqMask, _ = reqECS.Mask.Size()
	}

	if u.ecsIP != nil {
		setECS(resp, u.ecsIP, 24)
	}

	return resp, nil
}

// Address implements the upstream.Upstream interface for *testUpstream.  It
// always returns an empty string.
func (u *testUpstream) Address() (addr string) {
	return addr
}

// Close implements the upstream.Upstream interface for *testUpstream.  There
// is nothing to release, so it always returns nil.
func (u *testUpstream) Close() (err error) {
	return err
}

// TestProxy_Resolve_withOptimisticResolver checks optimistic caching.  While
// the background re-resolution of an expired cached response is in progress,
// every request for it is served from the cache with optimisticTTL.  After
// the background re-resolution completes, the cache holds the fresh response
// with its original TTL.
func TestProxy_Resolve_withOptimisticResolver(t *testing.T) {
	const (
		host             = "some.domain.name."
		nonOptimisticTTL = 3600
	)

	// buildCtx returns a new context with an A query for host.
	buildCtx := func() (dctx *DNSContext) {
		req := &dns.Msg{
			MsgHdr: dns.MsgHdr{
				Id: dns.Id(),
			},
			Question: []dns.Question{{
				Name:   host,
				Qtype:  dns.TypeA,
				Qclass: dns.ClassINET,
			}},
		}

		return &DNSContext{
			Req: req,
		}
	}
	// buildResp returns a reply to req with a single A record with the given
	// TTL.
	buildResp := func(req *dns.Msg, ttl uint32) (resp *dns.Msg) {
		resp = (&dns.Msg{}).SetReply(req)
		resp.Answer = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   host,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    ttl,
			},
			A: net.IP{1, 2, 3, 4},
		}}

		return resp
	}

	p := &Proxy{
		Config: Config{
			CacheEnabled:    true,
			CacheOptimistic: true,
		},
	}

	p.initCache()
	// out and in synchronize the test with the background resolution:
	// onCacheResp signals on out and then blocks until the test sends on in.
	out, in := make(chan unit), make(chan unit)
	p.shortFlighter.cr = &testCachingResolver{
		onReplyFromUpstream: func(dctx *DNSContext) (ok bool, err error) {
			dctx.Res = buildResp(dctx.Req, nonOptimisticTTL)

			return true, nil
		},
		onCacheResp: func(dctx *DNSContext) {
			// Report adding to cache is in process.
			out <- unit{}
			// Wait for tests to finish.
			<-in

			p.cacheResp(dctx)

			// Report adding to cache is finished.
			out <- unit{}
		},
	}

	// Two different contexts are made to emulate two different requests
	// with the same question section.
	firstCtx, secondCtx := buildCtx(), buildCtx()

	// Add expired response into cache.
	req := firstCtx.Req
	key := msgToKey(req)
	data := (&cacheItem{
		m: buildResp(req, 0),
		u: testUpsAddr,
	}).pack()
	items := glcache.New(glcache.Config{
		EnableLRU: true,
	})
	items.Set(key, data)
	p.cache.items = items

	err := p.Resolve(firstCtx)
	require.NoError(t, err)
	require.Len(t, firstCtx.Res.Answer, 1)

	assert.EqualValues(t, optimisticTTL, firstCtx.Res.Answer[0].Header().Ttl)

	// Wait for optimisticResolver to reach the tested function.
	<-out

	err = p.Resolve(secondCtx)
	require.NoError(t, err)
	require.Len(t, secondCtx.Res.Answer, 1)

	assert.EqualValues(t, optimisticTTL, secondCtx.Res.Answer[0].Header().Ttl)

	// Continue and wait for it to finish.
	in <- unit{}
	<-out

	// Should be served from cache.
	data = p.cache.items.Get(msgToKey(firstCtx.Req))
	unpacked, expired := p.cache.unpackItem(data, firstCtx.Req)
	require.False(t, expired)
	require.NotNil(t, unpacked)
	require.Len(t, unpacked.m.Answer, 1)

	assert.EqualValues(t, nonOptimisticTTL, unpacked.m.Answer[0].Header().Ttl)
}
07070100000055000081A4000000000000000000000001650C592100000BF4000000000000000000000000000000000000002400000000dnsproxy-0.55.0/proxy/proxycache.gopackage proxy

import (
	"net"

	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
)

// replyFromCache looks up d's request in the cache.  On a hit, it sets d.Res
// and d.CachedUpstreamAddr from the cached item.  If ECS is disabled or d has
// no client subnet, the general cache is used; otherwise the subnet cache is
// used.  With optimistic caching, an expired item is still returned while the
// request is resolved again in a separate goroutine.  hit is true if the
// cache had an item for the request.
func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) {
	var (
		ci      *cacheItem
		expired bool
		key     []byte
		hitMsg  string
	)

	switch {
	case !p.Config.EnableEDNSClientSubnet:
		ci, expired, key = p.cache.get(d.Req)
		hitMsg = "serving cached response"
	case d.ReqECS != nil:
		ci, expired, key = p.cache.getWithSubnet(d.Req, d.ReqECS)
		hitMsg = "serving response from subnet cache"
	default:
		ci, expired, key = p.cache.get(d.Req)
		hitMsg = "serving response from general cache"
	}

	if ci == nil {
		return false
	}

	d.Res = ci.m
	d.CachedUpstreamAddr = ci.u

	log.Debug("dnsproxy: cache: %s", hitMsg)

	if expired && p.cache.optimistic {
		// Resolve again using a reduced clone of the current context, so that
		// the background goroutine doesn't race with the caller over d.
		minCtxClone := &DNSContext{
			// It is only read inside the optimistic resolver.
			CustomUpstreamConfig: d.CustomUpstreamConfig,
			ReqECS:               netutil.CloneIPNet(d.ReqECS),
		}
		if d.Req != nil {
			minCtxClone.Req = d.Req.Copy()
			addDO(minCtxClone.Req)
		}

		go p.shortFlighter.ResolveOnce(minCtxClone, key)
	}

	return true
}

// cacheResp stores the response from d in the general or the subnet cache.
// If ECS is disabled, or the request had no client subnet, the general cache
// is used.  Otherwise, the response goes to the subnet cache:
//   - under the subnet from the response's ECS option, narrowed to its scope,
//     if that option matches the request's subnet;
//   - under an empty subnet if the response has no ECS option.
//
// A response whose ECS option doesn't match the request isn't cached at all.
func (p *Proxy) cacheResp(d *DNSContext) {
	if !p.EnableEDNSClientSubnet {
		// Without ECS, everything goes to the general cache.
		p.cache.set(d.Res, d.Upstream)

		return
	}

	switch ecs, scope := ecsFromMsg(d.Res); {
	case ecs != nil && d.ReqECS != nil:
		ones, bits := ecs.Mask.Size()
		reqOnes, _ := d.ReqECS.Mask.Size()

		// If FAMILY, SOURCE PREFIX-LENGTH, and SOURCE PREFIX-LENGTH bits of
		// ADDRESS in the response don't match the non-zero fields in the
		// corresponding query, the full response MUST be dropped.
		//
		// See RFC 7871 Section 7.3.
		//
		// TODO(a.meshkov):  The whole response MUST be dropped if ECS in it
		// doesn't correspond.
		if !ecs.IP.Mask(ecs.Mask).Equal(d.ReqECS.IP.Mask(d.ReqECS.Mask)) || ones != reqOnes {
			log.Debug("dnsproxy: cache: bad response: ecs %s does not match %s", ecs, d.ReqECS)

			return
		}

		// If SCOPE PREFIX-LENGTH is not longer than SOURCE PREFIX-LENGTH, store
		// SCOPE PREFIX-LENGTH bits of ADDRESS, and then mark the response as
		// valid for all addresses that fall within that range.  An equal
		// scope needs no change, hence the strict comparison.
		//
		// See RFC 7871 Section 7.3.1.
		if scope < reqOnes {
			ecs.Mask = net.CIDRMask(scope, bits)
			ecs.IP = ecs.IP.Mask(ecs.Mask)
		}

		log.Debug("dnsproxy: cache: ecs option in response: %s", ecs)

		p.cache.setWithSubnet(d.Res, d.Upstream, ecs)
	case d.ReqECS != nil:
		// Cache the response for all subnets since the server doesn't support
		// EDNS Client Subnet option.
		p.cache.setWithSubnet(d.Res, d.Upstream, &net.IPNet{IP: nil, Mask: nil})
	default:
		// The request had no client subnet, so use the general cache.
		p.cache.set(d.Res, d.Upstream)
	}
}

// ClearCache removes all items from both the general and the subnet parts of
// p's DNS cache.  It does nothing if p has no cache.
func (p *Proxy) ClearCache() {
	if p.cache == nil {
		return
	}

	p.cache.clearItems()
	p.cache.clearItemsWithSubnet()
	log.Debug("dnsproxy: cache: cleared")
}
07070100000056000081A4000000000000000000000001650C5921000005AC000000000000000000000000000000000000002300000000dnsproxy-0.55.0/proxy/ratelimit.gopackage proxy

import (
	"net"
	"time"

	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	rate "github.com/beefsack/go-rate"
	gocache "github.com/patrickmn/go-cache"
	"golang.org/x/exp/slices"
)

// limiterForIP returns the rate limiter stored for the IP address string ip.
// If there is none yet, it creates one allowing p.Ratelimit events per second
// and stores it for an hour.  The buckets cache itself is created lazily on
// the first call.  All access happens under p.ratelimitLock.
func (p *Proxy) limiterForIP(ip string) interface{} {
	p.ratelimitLock.Lock()
	defer p.ratelimitLock.Unlock()

	buckets := p.ratelimitBuckets
	if buckets == nil {
		buckets = gocache.New(time.Hour, time.Hour)
		p.ratelimitBuckets = buckets
	}

	if limiter, found := buckets.Get(ip); found {
		return limiter
	}

	limiter := rate.New(p.Ratelimit, time.Second)
	buckets.Set(ip, limiter, time.Hour)

	return limiter
}

// isRatelimited returns true if a request from addr exceeds the per-IP rate
// limit and should not be served.  It returns false if ratelimiting is
// disabled, if no IP can be extracted from addr, or if the IP is in
// p.RatelimitWhitelist.
func (p *Proxy) isRatelimited(addr net.Addr) (ok bool) {
	if p.Ratelimit <= 0 {
		// The ratelimit is disabled.
		return false
	}

	ip, _ := netutil.IPAndPortFromAddr(addr)
	if ip == nil {
		log.Printf("failed to split %v into host/port", addr)

		return false
	}

	ipStr := ip.String()

	// Search the allowlist without sorting it.  This method may be called
	// concurrently from request-handling goroutines, so sorting the shared
	// configuration slice in place would be a data race.  It would also be an
	// unexpected mutation of a slice owned by the caller.
	if slices.Contains(p.RatelimitWhitelist, ipStr) {
		// Don't ratelimit if the IP is allowlisted.
		return false
	}

	value := p.limiterForIP(ipStr)
	rl, ok := value.(*rate.RateLimiter)
	if !ok {
		log.Println("SHOULD NOT HAPPEN: unexpected entry type in ratelimit buckets cache")

		return false
	}

	allow, _ := rl.Try()

	return !allow
}
07070100000057000081A4000000000000000000000001650C592100000757000000000000000000000000000000000000002800000000dnsproxy-0.55.0/proxy/ratelimit_test.gopackage proxy

import (
	"net"
	"testing"
	"time"

	"github.com/miekg/dns"
)

// TestRatelimitingProxy checks that, with a ratelimit of one request per
// second, the proxy answers the first UDP request from a client but not the
// second one.
func TestRatelimitingProxy(t *testing.T) {
	dnsProxy := createTestProxy(t, nil)
	// Just one request per second is allowed.
	dnsProxy.Ratelimit = 1

	// Start listening and make sure the proxy is stopped even if the test
	// fails before reaching its end.
	err := dnsProxy.Start()
	if err != nil {
		t.Fatalf("cannot start the DNS proxy: %s", err)
	}
	t.Cleanup(func() {
		if stopErr := dnsProxy.Stop(); stopErr != nil {
			t.Errorf("cannot stop the DNS proxy: %s", stopErr)
		}
	})

	// Create a DNS-over-UDP client.
	addr := dnsProxy.Addr(ProtoUDP)
	client := &dns.Client{Net: "udp", Timeout: 500 * time.Millisecond}

	// Send the first message, which must not be blocked.
	req := createTestMessage()
	r, _, err := client.Exchange(req, addr.String())
	if err != nil {
		t.Fatalf("error in the first request: %s", err)
	}
	requireResponse(t, req, r)

	// Send the second message, which must be blocked.
	req = createTestMessage()
	_, _, err = client.Exchange(req, addr.String())
	if err == nil {
		t.Fatalf("second request was not blocked")
	}
}

// TestRatelimiting checks that, with a ratelimit of one request per second,
// the first request from an address is allowed and the second one is
// ratelimited.
func TestRatelimiting(t *testing.T) {
	p := &Proxy{}
	p.Ratelimit = 1

	addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1232}

	if p.isRatelimited(addr) {
		t.Fatal("First request must have been allowed")
	}

	if !p.isRatelimited(addr) {
		t.Fatal("Second request must have been ratelimited")
	}
}

// TestWhitelist checks that requests from an allowlisted address are never
// ratelimited, even with a ratelimit of one request per second.
func TestWhitelist(t *testing.T) {
	p := &Proxy{}
	p.Ratelimit = 1
	p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"}

	addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1232}

	for _, failMsg := range []string{
		"First request must have been allowed",
		"Second request must have been allowed due to whitelist",
	} {
		if p.isRatelimited(addr) {
			t.Fatal(failMsg)
		}
	}
}
07070100000058000081A4000000000000000000000001650C592100000549000000000000000000000000000000000000001E00000000dnsproxy-0.55.0/proxy/sema.gopackage proxy

import (
	"fmt"
)

// semaphore limits the number of simultaneously held resources.  acquire
// blocks until a resource can be acquired; release never blocks.
type semaphore interface {
	acquire()
	release()
}

// noopSemaphore is a semaphore without any limit.
type noopSemaphore struct{}

// acquire implements the semaphore interface for noopSemaphore.  It returns
// immediately.
func (noopSemaphore) acquire() {}

// release implements the semaphore interface for noopSemaphore.  It returns
// immediately.
func (noopSemaphore) release() {}

// newNoopSemaphore returns a semaphore that never limits anything.
func newNoopSemaphore() (s semaphore) {
	s = noopSemaphore{}

	return s
}

// sig is an alias for struct{} to type less.
type sig = struct{}

// chanSemaphore is a semaphore backed by a buffered channel, where every held
// resource occupies one slot of the buffer.
type chanSemaphore struct {
	c chan sig
}

// acquire implements the semaphore interface for *chanSemaphore.  It blocks
// while all the buffer slots are occupied.
func (s *chanSemaphore) acquire() {
	s.c <- sig{}
}

// release implements the semaphore interface for *chanSemaphore.  It frees
// one occupied slot if there is any and returns immediately otherwise.
func (s *chanSemaphore) release() {
	select {
	case <-s.c:
		// Freed one slot.
	default:
		// Nothing is held, so there is nothing to free.
	}
}

// newChanSemaphore returns a new chanSemaphore allowing at most maxRes
// simultaneously acquired resources.  It returns an error if maxRes is not
// greater than zero.
func newChanSemaphore(maxRes int) (s semaphore, err error) {
	if maxRes <= 0 {
		return nil, fmt.Errorf("bad maxRes: %d", maxRes)
	}

	return &chanSemaphore{c: make(chan sig, maxRes)}, nil
}
07070100000059000081A4000000000000000000000001650C59210000135D000000000000000000000000000000000000002000000000dnsproxy-0.55.0/proxy/server.gopackage proxy

import (
	"context"
	"fmt"
	"net"
	"time"

	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
)

// startListeners creates all the configured listeners and then starts serving
// each of them in its own goroutine.  If creating any listener fails, the
// first error is returned and no serving goroutines are started.  Errors
// returned by the HTTPS, HTTP/3, and DNSCrypt servers are ignored.
func (p *Proxy) startListeners(ctx context.Context) error {
	creators := []func() error{
		func() error { return p.createUDPListeners(ctx) },
		func() error { return p.createTCPListeners(ctx) },
		func() error { return p.createTLSListeners() },
		func() error { return p.createHTTPSListeners() },
		func() error { return p.createQUICListeners() },
		func() error { return p.createDNSCryptListeners() },
	}
	for _, create := range creators {
		if err := create(); err != nil {
			return err
		}
	}

	for _, conn := range p.udpListen {
		go p.udpPacketLoop(conn, p.requestGoroutinesSema)
	}

	for _, ln := range p.tcpListen {
		go p.tcpPacketLoop(ln, ProtoTCP, p.requestGoroutinesSema)
	}

	for _, ln := range p.tlsListen {
		go p.tcpPacketLoop(ln, ProtoTLS, p.requestGoroutinesSema)
	}

	for _, ln := range p.httpsListen {
		go func(httpsLn net.Listener) { _ = p.httpsServer.Serve(httpsLn) }(ln)
	}

	for _, ln := range p.h3Listen {
		go func(h3Ln *quic.EarlyListener) { _ = p.h3Server.ServeListener(h3Ln) }(ln)
	}

	for _, ln := range p.quicListen {
		go p.quicPacketLoop(ln, p.requestGoroutinesSema)
	}

	for _, conn := range p.dnsCryptUDPListen {
		go func(udpConn *net.UDPConn) { _ = p.dnsCryptServer.ServeUDP(udpConn) }(conn)
	}

	for _, ln := range p.dnsCryptTCPListen {
		go func(tcpLn net.Listener) { _ = p.dnsCryptServer.ServeTCP(tcpLn) }(ln)
	}

	return nil
}

// handleDNSRequest processes the request in d and writes the response, if
// any, back to the client.  It:
//
//   - drops messages that are responses themselves;
//   - runs BeforeRequestHandler, if set, replying with SERVFAIL on its error
//     and dropping the request when it returns false;
//   - applies IP-based rate limiting to plain UDP requests;
//   - replies with SERVFAIL to requests without exactly one question and with
//     NOTIMP to ANY requests when RefuseAny is set;
//   - otherwise, if d.Res is still nil, resolves the request with
//     RequestHandler or, when it isn't set, with Resolve.
//
// The returned error is non-nil only when resolving fails; even then d.Res is
// still passed to respond.
func (p *Proxy) handleDNSRequest(d *DNSContext) error {
	d.StartTime = time.Now()
	p.logDNSMessage(d.Req)

	if d.Req.Response {
		// Pass d.Addr itself rather than d.Addr.String(): the address may be
		// nil, e.g. for a DoH request whose remote address couldn't be parsed,
		// and the formatter handles that without panicking.
		log.Debug("Dropping incoming Reply packet from %s", d.Addr)

		return nil
	}

	if p.BeforeRequestHandler != nil {
		ok, err := p.BeforeRequestHandler(p, d)
		if err != nil {
			log.Error("Error in the BeforeRequestHandler: %s", err)
			d.Res = p.genServerFailure(d.Req)
			p.respond(d)

			return nil
		}

		if !ok {
			// Do nothing, don't reply.
			return nil
		}
	}

	// Ratelimit based on IP only, protects CPU cycles and outbound connections.
	if d.Proto == ProtoUDP && p.isRatelimited(d.Addr) {
		log.Tracef("Ratelimiting %v based on IP only", d.Addr)

		// Do nothing, don't reply, the client got ratelimited.
		return nil
	}

	// The question count is validated first so that a malformed request is
	// always answered with SERVFAIL, even if its first question is of type
	// ANY.
	switch {
	case len(d.Req.Question) != 1:
		log.Debug("got invalid number of questions: %v", len(d.Req.Question))
		d.Res = p.genServerFailure(d.Req)
	case p.RefuseAny && d.Req.Question[0].Qtype == dns.TypeANY:
		// Refuse ANY requests (anti-DDOS measure).
		log.Tracef("Refusing type=ANY request")
		d.Res = p.genNotImpl(d.Req)
	}

	var err error

	if d.Res == nil {
		if len(p.UpstreamConfig.Upstreams) == 0 {
			panic("SHOULD NOT HAPPEN: no default upstreams specified")
		}

		// Execute the DNS request using the custom middleware, if there is
		// one configured.
		if p.RequestHandler != nil {
			err = p.RequestHandler(p, d)
		} else {
			err = p.Resolve(d)
		}

		if err != nil {
			err = fmt.Errorf("talking to dns upstream: %w", err)
		}
	}

	p.logDNSMessage(d.Res)
	p.respond(d)

	return err
}

// respond writes d.Res to the client using the transport that d.Proto names.
// The protocol-specific writers decide what to do with a nil d.Res.  Write
// errors are logged rather than returned.
func (p *Proxy) respond(d *DNSContext) {
	// d.Conn is nil for DoH requests, so there is no deadline to set then.
	if conn := d.Conn; conn != nil {
		_ = conn.SetWriteDeadline(time.Now().Add(defaultTimeout))
	}

	var err error

	switch d.Proto {
	case ProtoUDP:
		err = p.respondUDP(d)
	case ProtoTCP, ProtoTLS:
		err = p.respondTCP(d)
	case ProtoHTTPS:
		err = p.respondHTTPS(d)
	case ProtoQUIC:
		err = p.respondQUIC(d)
	case ProtoDNSCrypt:
		err = p.respondDNSCrypt(d)
	default:
		err = fmt.Errorf("SHOULD NOT HAPPEN - unknown protocol: %s", d.Proto)
	}

	if err == nil {
		return
	}

	logWithNonCrit(err, fmt.Sprintf("responding %s request", d.Proto))
}

// setMinMaxTTL replaces the TTL of every answer record in r with the value
// respectTTLOverrides computes from it and the CacheMinTTL and CacheMaxTTL
// settings.  Records whose TTL doesn't change are left untouched.
func (p *Proxy) setMinMaxTTL(r *dns.Msg) {
	for i := range r.Answer {
		hdr := r.Answer[i].Header()
		oldTTL := hdr.Ttl

		newTTL := respectTTLOverrides(oldTTL, p.CacheMinTTL, p.CacheMaxTTL)
		if newTTL == oldTTL {
			continue
		}

		log.Debug("Override TTL from %d to %d", oldTTL, newTTL)
		hdr.Ttl = newTTL
	}
}

// genServerFailure returns a SERVFAIL response to request.
func (p *Proxy) genServerFailure(request *dns.Msg) *dns.Msg {
	servfail := p.genWithRCode(request, dns.RcodeServerFailure)

	return servfail
}

// genNotImpl returns a NOTIMP response to request that carries an EDNS0 OPT
// record.
func (p *Proxy) genNotImpl(request *dns.Msg) (resp *dns.Msg) {
	// udpSize is the UDP payload size advertised in the OPT record.
	const udpSize = 1452

	resp = p.genWithRCode(request, dns.RcodeNotImplemented)

	// A NOTIMP response without EDNS is treated as "this server doesn't
	// support EDNS", so the OPT record is added explicitly.
	resp.SetEdns0(udpSize, false)

	return resp
}

// genWithRCode returns a response to req with the given response code and the
// RA (recursion available) bit set.
func (p *Proxy) genWithRCode(req *dns.Msg, code int) (resp *dns.Msg) {
	resp = new(dns.Msg)
	resp.SetRcode(req, code)
	resp.RecursionAvailable = true

	return resp
}

// logDNSMessage writes m to the trace log, marking responses as "OUT" and
// everything else as "IN".  A nil m is ignored.
func (p *Proxy) logDNSMessage(m *dns.Msg) {
	switch {
	case m == nil:
		// Nothing to log.
	case m.Response:
		log.Tracef("OUT: %s", m)
	default:
		log.Tracef("IN: %s", m)
	}
}
0707010000005A000081A4000000000000000000000001650C59210000095C000000000000000000000000000000000000002900000000dnsproxy-0.55.0/proxy/server_dnscrypt.gopackage proxy

import (
	"fmt"
	"net"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
)

// createDNSCryptListeners sets up the DNSCrypt server and opens a UDP and a
// TCP listener for every configured DNSCrypt address.  It does nothing when
// no DNSCrypt addresses are configured and returns an error when the resolver
// certificate or the provider name is missing.
func (p *Proxy) createDNSCryptListeners() (err error) {
	if len(p.DNSCryptUDPListenAddr)+len(p.DNSCryptTCPListenAddr) == 0 {
		// DNSCrypt isn't configured.
		return nil
	}

	if p.DNSCryptResolverCert == nil || p.DNSCryptProviderName == "" {
		return errors.Error("invalid DNSCrypt configuration: no certificate or provider name")
	}

	log.Info("Initializing DNSCrypt: %s", p.DNSCryptProviderName)

	handler := &dnsCryptHandler{
		proxy:                 p,
		requestGoroutinesSema: p.requestGoroutinesSema,
	}
	p.dnsCryptServer = &dnscrypt.Server{
		ProviderName: p.DNSCryptProviderName,
		ResolverCert: p.DNSCryptResolverCert,
		Handler:      handler,
	}

	for _, udpAddr := range p.DNSCryptUDPListenAddr {
		log.Info("Creating a DNSCrypt UDP listener")

		conn, listenErr := net.ListenUDP("udp", udpAddr)
		if listenErr != nil {
			return fmt.Errorf("listening to dnscrypt udp socket: %w", listenErr)
		}

		p.dnsCryptUDPListen = append(p.dnsCryptUDPListen, conn)
		log.Info("Listening for DNSCrypt messages on udp://%s", conn.LocalAddr())
	}

	for _, tcpAddr := range p.DNSCryptTCPListenAddr {
		log.Info("Creating a DNSCrypt TCP listener")

		ln, listenErr := net.ListenTCP("tcp", tcpAddr)
		if listenErr != nil {
			return fmt.Errorf("listening to dnscrypt tcp socket: %w", listenErr)
		}

		p.dnsCryptTCPListen = append(p.dnsCryptTCPListen, ln)
		log.Info("Listening for DNSCrypt messages on tcp://%s", ln.Addr())
	}

	return nil
}

// dnsCryptHandler implements dnscrypt.Handler by passing every query to the
// proxy.
type dnsCryptHandler struct {
	// proxy is the proxy that handles the queries.
	proxy *Proxy

	// requestGoroutinesSema is held by ServeDNS while a query is handled.
	requestGoroutinesSema semaphore
}

// type check
var _ dnscrypt.Handler = (*dnsCryptHandler)(nil)

// ServeDNS implements the dnscrypt.Handler interface for *dnsCryptHandler.
// It wraps req into a DNS context and passes it to the proxy while holding
// the request semaphore.
func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) error {
	dctx := h.proxy.newDNSContext(ProtoDNSCrypt, req)
	dctx.Addr, dctx.DNSCryptResponseWriter = rw.RemoteAddr(), rw

	sema := h.requestGoroutinesSema
	sema.acquire()
	defer sema.release()

	return h.proxy.handleDNSRequest(dctx)
}

// respondDNSCrypt writes d.Res to the DNSCrypt client, over UDP or TCP,
// through d.DNSCryptResponseWriter.  When d.Res is nil, nothing is written
// and the request is silently dropped.
func (p *Proxy) respondDNSCrypt(d *DNSContext) error {
	if res := d.Res; res != nil {
		return d.DNSCryptResponseWriter.WriteMsg(res)
	}

	return nil
}
0707010000005B000081A4000000000000000000000001650C59210000044A000000000000000000000000000000000000002E00000000dnsproxy-0.55.0/proxy/server_dnscrypt_test.gopackage proxy

import (
	"fmt"
	"net"
	"testing"

	"github.com/ameshkov/dnscrypt/v2"
	"github.com/ameshkov/dnsstamps"
	"github.com/stretchr/testify/assert"
)

// TestDNSCryptProxy starts a DNSCrypt-enabled proxy and checks that it answers
// queries over both UDP and TCP.
func TestDNSCryptProxy(t *testing.T) {
	// Prepare the proxy server.
	dnsProxy, rc := createTestDNSCryptProxy(t)

	// Start listening.  Stop the test on failure: without a listener the
	// address lookup below may panic instead of reporting the real error.
	err := dnsProxy.Start()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer func() {
		assert.NoError(t, dnsProxy.Stop())
	}()

	// Generate a DNS stamp.
	addr := fmt.Sprintf("%s:%d", listenIP, dnsProxy.Addr(ProtoDNSCrypt).(*net.UDPAddr).Port)
	stamp, err := rc.CreateStamp(addr)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Test DNSCrypt proxy on both UDP and TCP.
	checkDNSCryptProxy(t, "udp", stamp)
	checkDNSCryptProxy(t, "tcp", stamp)
}

// checkDNSCryptProxy sends a test message to the DNSCrypt server described by
// stamp using the network proto ("udp" or "tcp") and checks the response.
func checkDNSCryptProxy(t *testing.T, proto string, stamp dnsstamps.ServerStamp) {
	t.Helper()

	// Create a DNSCrypt client.
	c := &dnscrypt.Client{
		Timeout: defaultTimeout,
		Net:     proto,
	}

	// Fetch the server certificate.  Without it there is nothing to exchange
	// messages with, so stop the test on failure.
	ri, err := c.DialStamp(stamp)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	// Send the test message.
	msg := createTestMessage()
	reply, err := c.Exchange(msg, ri)
	if !assert.NoError(t, err) {
		t.FailNow()
	}

	requireResponse(t, msg, reply)
}
0707010000005C000081A4000000000000000000000001650C592100001CE9000000000000000000000000000000000000002600000000dnsproxy-0.55.0/proxy/server_https.gopackage proxy

import (
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"

	"github.com/AdguardTeam/golibs/httphdr"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"golang.org/x/net/http2"
)

// listenHTTP opens a TCP listener on addr, wraps it into a TLS listener that
// offers HTTP/2 and HTTP/1.1 via ALPN, and stores it for the H1/H2 server.  It
// returns the address the listener actually listens on, which is useful when
// port 0 is specified.
func (p *Proxy) listenHTTP(addr *net.TCPAddr) (laddr *net.TCPAddr, err error) {
	ln, err := net.ListenTCP("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("tcp listener: %w", err)
	}

	laddr = ln.Addr().(*net.TCPAddr)
	log.Info("Listening to https://%s", laddr)

	conf := p.TLSConfig.Clone()
	conf.NextProtos = []string{http2.NextProtoTLS, "http/1.1"}
	p.httpsListen = append(p.httpsListen, tls.NewListener(ln, conf))

	return laddr, nil
}

// listenH3 opens a QUIC listener on addr that offers the "h3" ALPN token and
// stores it for the HTTP/3 server.
func (p *Proxy) listenH3(addr *net.UDPAddr) (err error) {
	conf := p.TLSConfig.Clone()
	conf.NextProtos = []string{"h3"}

	ln, err := quic.ListenAddrEarly(addr.String(), conf, newServerQUICConfig())
	if err != nil {
		return fmt.Errorf("quic listener: %w", err)
	}

	log.Info("Listening to h3://%s", ln.Addr())
	p.h3Listen = append(p.h3Listen, ln)

	return nil
}

// createHTTPSListeners creates the DoH HTTP server and, when HTTP3 is set, the
// HTTP/3 server.  For every address in HTTPSListenAddr it opens a TLS
// listener and, with HTTP3, a QUIC listener on the same IP and port.
func (p *Proxy) createHTTPSListeners() (err error) {
	p.httpsServer = &http.Server{
		Handler:           p,
		ReadHeaderTimeout: defaultTimeout,
		WriteTimeout:      defaultTimeout,
	}

	if p.HTTP3 {
		p.h3Server = &http3.Server{Handler: p}
	}

	for _, addr := range p.HTTPSListenAddr {
		log.Info("Creating an HTTPS server")

		actual, listenErr := p.listenHTTP(addr)
		if listenErr != nil {
			return fmt.Errorf("failed to start HTTPS server on %s: %w", addr, listenErr)
		}

		if !p.HTTP3 {
			continue
		}

		// The HTTP/3 server listens on the same IP:port pair as the HTTP/2
		// server, but over UDP.
		udpAddr := &net.UDPAddr{IP: actual.IP, Port: actual.Port}
		if err = p.listenH3(udpAddr); err != nil {
			return fmt.Errorf("failed to start HTTP/3 server on %s: %w", udpAddr, err)
		}
	}

	return nil
}

// ServeHTTP is the http.Handler implementation that handles DoH queries.
// Here is what it returns:
//
//   - http.StatusBadRequest if there is no DNS request data or it can't be
//     read or unpacked;
//   - http.StatusRequestEntityTooLarge if the POST body is larger than the
//     largest possible DNS message, dns.MaxMsgSize bytes;
//   - http.StatusUnsupportedMediaType if request content type is not
//     "application/dns-message";
//   - http.StatusMethodNotAllowed if request method is not GET or POST.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	log.Tracef("Incoming HTTPS request on %s", r.URL)

	var buf []byte
	var err error

	switch r.Method {
	case http.MethodGet:
		dnsParam := r.URL.Query().Get("dns")
		buf, err = base64.RawURLEncoding.DecodeString(dnsParam)
		if len(buf) == 0 || err != nil {
			log.Tracef("Cannot parse DNS request from %s", dnsParam)
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)

			return
		}
	case http.MethodPost:
		contentType := r.Header.Get(httphdr.ContentType)
		if contentType != "application/dns-message" {
			log.Tracef("Unsupported media type: %s", contentType)
			http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)

			return
		}

		// Close the body on every path below, including the error ones.
		defer log.OnCloserError(r.Body, log.DEBUG)

		// Read at most one byte more than the largest possible DNS message.
		// An oversized body is then detected without buffering all of it,
		// which prevents a client from exhausting memory with a huge body.
		buf, err = io.ReadAll(io.LimitReader(r.Body, dns.MaxMsgSize+1))
		if err != nil {
			log.Tracef("Cannot read the request body: %s", err)
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)

			return
		}

		if len(buf) > dns.MaxMsgSize {
			log.Tracef("Request body is larger than %d bytes", dns.MaxMsgSize)
			http.Error(
				w,
				http.StatusText(http.StatusRequestEntityTooLarge),
				http.StatusRequestEntityTooLarge,
			)

			return
		}
	default:
		log.Tracef("Wrong HTTP method: %s", r.Method)
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)

		return
	}

	req := &dns.Msg{}
	if err = req.Unpack(buf); err != nil {
		log.Tracef("msg.Unpack: %s", err)
		http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)

		return
	}

	// On error, addr is nil.  The request is still handled, but without a
	// client address.
	addr, prx, err := remoteAddr(r)
	if err != nil {
		log.Debug("warning: getting real ip: %s", err)
	}

	d := p.newDNSContext(ProtoHTTPS, req)
	d.Addr = addr
	d.HTTPRequest = r
	d.HTTPResponseWriter = w

	if prx != nil {
		ip, _ := netutil.IPAndPortFromAddr(prx)
		log.Debug("request came from proxy server %s", prx)

		// The client address from the headers is only trusted when it comes
		// from a trusted proxy.
		if !p.proxyVerifier.Contains(ip) {
			log.Debug("proxy %s is not trusted, using original remote addr", ip)
			d.Addr = prx
		}
	}

	err = p.handleDNSRequest(d)
	if err != nil {
		log.Tracef("error handling DNS (%s) request: %s", d.Proto, err)
	}
}

// respondHTTPS writes d.Res to the DoH client as an "application/dns-message"
// body.  When d.Res is nil or can't be packed, the client gets
// http.StatusInternalServerError instead; only the packing error is returned.
func (p *Proxy) respondHTTPS(d *DNSContext) (err error) {
	w := d.HTTPResponseWriter

	if d.Res == nil {
		// Indicate the response's absence via a http.StatusInternalServerError.
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return nil
	}

	packed, err := d.Res.Pack()
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return fmt.Errorf("packing message: %w", err)
	}

	hdr := w.Header()
	if name := p.Config.HTTPSServerName; name != "" {
		hdr.Set(httphdr.Server, name)
	}
	hdr.Set(httphdr.ContentType, "application/dns-message")

	_, err = w.Write(packed)

	return err
}

// realIPFromHdrs returns the client's IP address from the first suitable
// header of r, or nil if none of them contains a valid address.  The headers
// are checked in this order:
//
//  1. [httphdr.CFConnectingIP]
//  2. [httphdr.TrueClientIP]
//  3. [httphdr.XRealIP]
//  4. [httphdr.XForwardedFor]
//
// Surrounding whitespace is ignored.  For X-Forwarded-For, only the first
// comma-separated element is considered.
func realIPFromHdrs(r *http.Request) (realIP net.IP) {
	singleIPHdrs := []string{
		httphdr.CFConnectingIP,
		httphdr.TrueClientIP,
		httphdr.XRealIP,
	}
	for _, name := range singleIPHdrs {
		if ip := net.ParseIP(strings.TrimSpace(r.Header.Get(name))); ip != nil {
			return ip
		}
	}

	first, _, _ := strings.Cut(r.Header.Get(httphdr.XForwardedFor), ",")

	return net.ParseIP(strings.TrimSpace(first))
}

// remoteAddr returns the client's address and, when the request carries a
// client-IP header, the address of the proxy that sent it.
//
// For HTTP/3 requests the addresses are *net.UDPAddr, otherwise
// *net.TCPAddr.  An address taken from a header always has port 0.  An error
// is returned when r.RemoteAddr can't be parsed.
func remoteAddr(r *http.Request) (addr, prx net.Addr, err error) {
	hostStr, portStr, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return nil, nil, err
	}

	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, nil, err
	}

	host := net.ParseIP(hostStr)
	if host == nil {
		return nil, nil, fmt.Errorf("invalid ip: %s", hostStr)
	}

	isH3 := r.Context().Value(http3.ServerContextKey) != nil
	newAddr := func(ip net.IP, portNum int) net.Addr {
		if isH3 {
			return &net.UDPAddr{IP: ip, Port: portNum}
		}

		return &net.TCPAddr{IP: ip, Port: portNum}
	}

	realIP := realIPFromHdrs(r)
	if realIP == nil {
		return newAddr(host, port), nil, nil
	}

	log.Tracef("Using IP address from HTTP request: %s", realIP)

	// TODO(a.garipov): Add port if we can get it from headers like
	// X-Real-Port, X-Forwarded-Port, etc.
	return newAddr(realIP, 0), newAddr(host, port), nil
}
0707010000005D000081A4000000000000000000000001650C59210000273E000000000000000000000000000000000000002B00000000dnsproxy-0.55.0/proxy/server_https_test.gopackage proxy

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"
	"testing"

	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestHttpsProxy checks that a DoH query is answered over HTTP/2 and, with
// HTTP3 enabled, over HTTP/3.
func TestHttpsProxy(t *testing.T) {
	for _, tc := range []struct {
		name  string
		http3 bool
	}{
		{name: "https_proxy", http3: false},
		{name: "h3_proxy", http3: true},
	} {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			tlsConf, caPem := createServerTLSConfig(t)
			dnsProxy := createTestProxy(t, tlsConf)
			dnsProxy.HTTP3 = tc.http3

			require.NoError(t, dnsProxy.Start())
			testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

			client := createTestHTTPClient(dnsProxy, caPem, tc.http3)

			// Send a test message and check that the response matches it.
			msg := createTestMessage()
			requireResponse(t, msg, sendTestDoHMessage(t, client, msg, nil))
		})
	}
}

// TestProxy_trustedProxies checks that the client address from the
// X-Forwarded-For header is only used when the connecting proxy is trusted,
// and that the proxy's own address is used otherwise.
func TestProxy_trustedProxies(t *testing.T) {
	clientIP, proxyIP := net.IP{1, 2, 3, 4}, net.IP{127, 0, 0, 1}

	testCases := []struct {
		name     string
		trusted  string
		wantAddr net.IP
	}{{
		name:     "success",
		trusted:  proxyIP.String(),
		wantAddr: clientIP,
	}, {
		name:     "not_in_trusted",
		trusted:  "127.0.0.2",
		wantAddr: proxyIP,
	}}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			tlsConf, caPem := createServerTLSConfig(t)
			dnsProxy := createTestProxy(t, tlsConf)

			// Record the address the request is handled with.
			var gotAddr net.Addr
			dnsProxy.RequestHandler = func(_ *Proxy, d *DNSContext) (err error) {
				gotAddr = d.Addr

				return dnsProxy.Resolve(d)
			}

			client := createTestHTTPClient(dnsProxy, caPem, false)
			msg := createTestMessage()

			// Trusted proxies must be set before the proxy starts.
			dnsProxy.TrustedProxies = []string{tc.trusted}

			require.NoError(t, dnsProxy.Start())
			testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

			xff := clientIP.String() + "," + proxyIP.String()
			resp := sendTestDoHMessage(t, client, msg, map[string]string{"X-Forwarded-For": xff})
			requireResponse(t, msg, resp)

			ip, _ := netutil.IPAndPortFromAddr(gotAddr)
			require.True(t, ip.Equal(tc.wantAddr))
		})
	}
}

// TestAddrsFromRequest checks that realIPFromHdrs picks the client IP from
// the supported headers in priority order, trims surrounding whitespace, and
// uses only the first element of X-Forwarded-For.
func TestAddrsFromRequest(t *testing.T) {
	theIP, anotherIP := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5}
	theIPStr, anotherIPStr := theIP.String(), anotherIP.String()

	testCases := []struct {
		name   string
		hdrs   map[string]string
		wantIP net.IP
	}{{
		name: "cf-connecting-ip",
		hdrs: map[string]string{
			"CF-Connecting-IP": theIPStr,
		},
		wantIP: theIP,
	}, {
		name: "true-client-ip",
		hdrs: map[string]string{
			"True-Client-IP": theIPStr,
		},
		wantIP: theIP,
	}, {
		name: "x-real-ip",
		hdrs: map[string]string{
			"X-Real-IP": theIPStr,
		},
		wantIP: theIP,
	}, {
		// Unparseable values in every single-IP header yield no address.
		name: "no_any",
		hdrs: map[string]string{
			"CF-Connecting-IP": "invalid",
			"True-Client-IP":   "invalid",
			"X-Real-IP":        "invalid",
		},
		wantIP: nil,
	}, {
		// CF-Connecting-IP wins over all the other headers.
		name: "priority",
		hdrs: map[string]string{
			"X-Forwarded-For":  strings.Join([]string{anotherIPStr, theIPStr}, ","),
			"True-Client-IP":   anotherIPStr,
			"X-Real-IP":        anotherIPStr,
			"CF-Connecting-IP": theIPStr,
		},
		wantIP: theIP,
	}, {
		// Only the first element of X-Forwarded-For is used.
		name: "x-forwarded-for_simple",
		hdrs: map[string]string{
			"X-Forwarded-For": strings.Join([]string{anotherIPStr, theIPStr}, ","),
		},
		wantIP: anotherIP,
	}, {
		name: "x-forwarded-for_single",
		hdrs: map[string]string{
			"X-Forwarded-For": theIPStr,
		},
		wantIP: theIP,
	}, {
		// Elements after the first one aren't parsed at all.
		name: "x-forwarded-for_invalid_proxy",
		hdrs: map[string]string{
			"X-Forwarded-For": strings.Join([]string{theIPStr, "invalid"}, ","),
		},
		wantIP: theIP,
	}, {
		name: "x-forwarded-for_empty",
		hdrs: map[string]string{
			"X-Forwarded-For": "",
		},
		wantIP: nil,
	}, {
		name: "x-forwarded-for_redundant_spaces",
		hdrs: map[string]string{
			"X-Forwarded-For": "  " + theIPStr + "   ,\t" + anotherIPStr,
		},
		wantIP: theIP,
	}, {
		name: "cf-connecting-ip_redundant_spaces",
		hdrs: map[string]string{
			"CF-Connecting-IP": "  " + theIPStr + "\t",
		},
		wantIP: theIP,
	}}

	for _, tc := range testCases {
		// The request is built outside of t.Run, so a construction failure
		// fails the parent test.
		r, err := http.NewRequest(http.MethodGet, "localhost", nil)
		require.NoError(t, err)

		for h, v := range tc.hdrs {
			r.Header.Set(h, v)
		}

		t.Run(tc.name, func(t *testing.T) {
			ip := realIPFromHdrs(r)
			assert.True(t, tc.wantIP.Equal(ip))
		})
	}
}

// TestRemoteAddr checks that remoteAddr parses the request's RemoteAddr.  When
// a client-IP header is present, the header's address must be reported as the
// client and RemoteAddr as the proxy.
func TestRemoteAddr(t *testing.T) {
	theIP, anotherIP, thirdIP := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5}, net.IP{1, 2, 3, 6}
	theIPStr, anotherIPStr, thirdIPStr := theIP.String(), anotherIP.String(), thirdIP.String()
	rAddr := &net.TCPAddr{IP: theIP, Port: 1}

	testCases := []struct {
		name       string
		remoteAddr string
		hdrs       map[string]string
		wantErr    string
		wantIP     net.IP
		wantProxy  net.IP
	}{{
		name:       "no_proxy",
		remoteAddr: rAddr.String(),
		hdrs:       nil,
		wantErr:    "",
		wantIP:     theIP,
		wantProxy:  nil,
	}, {
		name:       "proxied_with_cloudflare",
		remoteAddr: rAddr.String(),
		hdrs: map[string]string{
			"CF-Connecting-IP": anotherIPStr,
		},
		wantErr:   "",
		wantIP:    anotherIP,
		wantProxy: theIP,
	}, {
		name:       "proxied_once",
		remoteAddr: rAddr.String(),
		hdrs: map[string]string{
			"X-Forwarded-For": anotherIPStr,
		},
		wantErr:   "",
		wantIP:    anotherIP,
		wantProxy: theIP,
	}, {
		name:       "proxied_multiple",
		remoteAddr: rAddr.String(),
		hdrs: map[string]string{
			"X-Forwarded-For": strings.Join([]string{anotherIPStr, thirdIPStr}, ","),
		},
		wantErr:   "",
		wantIP:    anotherIP,
		wantProxy: theIP,
	}, {
		name:       "no_port",
		remoteAddr: theIPStr,
		hdrs:       nil,
		wantErr:    "address " + theIPStr + ": missing port in address",
		wantIP:     nil,
		wantProxy:  nil,
	}, {
		name:       "bad_port",
		remoteAddr: theIPStr + ":notport",
		hdrs:       nil,
		wantErr:    "strconv.Atoi: parsing \"notport\": invalid syntax",
		wantIP:     nil,
		wantProxy:  nil,
	}, {
		name:       "bad_host",
		remoteAddr: "host:1",
		hdrs:       nil,
		wantErr:    "invalid ip: host",
		wantIP:     nil,
		wantProxy:  nil,
	}, {
		name:       "bad_proxied_host",
		remoteAddr: "host:1",
		hdrs: map[string]string{
			"CF-Connecting-IP": theIPStr,
		},
		wantErr:   "invalid ip: host",
		wantIP:    nil,
		wantProxy: nil,
	}}

	for _, tc := range testCases {
		r, err := http.NewRequest(http.MethodGet, "localhost", nil)
		require.NoError(t, err)

		r.RemoteAddr = tc.remoteAddr
		for h, v := range tc.hdrs {
			r.Header.Set(h, v)
		}

		t.Run(tc.name, func(t *testing.T) {
			addr, prx, aErr := remoteAddr(r)
			if tc.wantErr != "" {
				// EqualError reports a failure, instead of panicking, if aErr
				// is unexpectedly nil.
				assert.EqualError(t, aErr, tc.wantErr)

				return
			}

			require.NoError(t, aErr)

			ip, _ := netutil.IPAndPortFromAddr(addr)
			assert.True(t, ip.Equal(tc.wantIP))

			prxIP, _ := netutil.IPAndPortFromAddr(prx)
			assert.True(t, tc.wantProxy.Equal(prxIP))
		})
	}
}

// sendTestDoHMessage sends m to the test DoH server using client, with the
// additional request headers hdrs, and returns the DNS response.  HTTP/3
// clients send the request with 0-RTT.  The test fails if the response is not
// successful, uses a protocol older than HTTP/2, or can't be unpacked.
func sendTestDoHMessage(
	t *testing.T,
	client *http.Client,
	m *dns.Msg,
	hdrs map[string]string,
) (resp *dns.Msg) {
	t.Helper()

	packed, err := m.Pack()
	require.NoError(t, err)

	u := url.URL{
		Scheme:   "https",
		Host:     tlsServerName,
		Path:     "/dns-query",
		RawQuery: fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(packed)),
	}

	method := http.MethodGet
	if _, ok := client.Transport.(*http3.RoundTripper); ok {
		// If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT.
		method = http3.MethodGet0RTT
	}

	req, err := http.NewRequest(method, u.String(), nil)
	require.NoError(t, err)

	req.Header.Set("Content-Type", "application/dns-message")
	req.Header.Set("Accept", "application/dns-message")

	for k, v := range hdrs {
		req.Header.Set(k, v)
	}

	httpResp, err := client.Do(req) // nolint:bodyclose
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, httpResp.Body.Close)

	// Check the status first: an error response carries a plain-text body
	// that would otherwise only surface as a confusing unpacking error.
	require.Equal(t, http.StatusOK, httpResp.StatusCode)

	require.True(
		t,
		httpResp.ProtoAtLeast(2, 0),
		"the proto is too old: %s",
		httpResp.Proto,
	)

	body, err := io.ReadAll(httpResp.Body)
	require.NoError(t, err)

	resp = &dns.Msg{}
	err = resp.Unpack(body)
	require.NoError(t, err)

	return resp
}

// createTestHTTPClient returns an *http.Client that trusts the CA in caPem and
// sends every request to the DoH address of dnsProxy, regardless of the host
// in the request URL.  With http3Enabled it uses HTTP/3; otherwise it uses
// HTTP/2 with an HTTP/1.1 fallback.
func createTestHTTPClient(dnsProxy *Proxy, caPem []byte, http3Enabled bool) (client *http.Client) {
	// Trust the test CA so that the server certificate validates.
	roots := x509.NewCertPool()
	roots.AppendCertsFromPEM(caPem)

	tlsClientConfig := &tls.Config{
		ServerName: tlsServerName,
		RootCAs:    roots,
	}

	// serverAddr is evaluated on every dial, after the proxy has started.
	serverAddr := func() string { return dnsProxy.Addr(ProtoHTTPS).String() }

	var transport http.RoundTripper
	if http3Enabled {
		tlsClientConfig.NextProtos = []string{"h3"}
		transport = &http3.RoundTripper{
			Dial: func(
				ctx context.Context,
				_ string,
				tlsCfg *tls.Config,
				cfg *quic.Config,
			) (quic.EarlyConnection, error) {
				return quic.DialAddrEarly(ctx, serverAddr(), tlsCfg, cfg)
			},
			TLSClientConfig:    tlsClientConfig,
			QuicConfig:         &quic.Config{},
			DisableCompression: true,
		}
	} else {
		tlsClientConfig.NextProtos = []string{"h2", "http/1.1"}
		dialer := &net.Dialer{Timeout: defaultTimeout}
		transport = &http.Transport{
			TLSClientConfig:    tlsClientConfig,
			DisableCompression: true,
			DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, network, serverAddr())
			},
			ForceAttemptHTTP2: true,
		}
	}

	return &http.Client{
		Transport: transport,
		Timeout:   defaultTimeout,
	}
}
0707010000005E000081A4000000000000000000000001650C5921000038CD000000000000000000000000000000000000002500000000dnsproxy-0.55.0/proxy/server_quic.gopackage proxy

import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"math"
	"net"
	"time"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/bluele/gcache"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
)

// NextProtoDQ is the ALPN token for DoQ.  During connection establishment,
// DNS/QUIC support is indicated by selecting the ALPN token "doq" in the
// crypto handshake.
//
// DoQ RFC: https://www.rfc-editor.org/rfc/rfc9250.html
const NextProtoDQ = "doq"

// compatProtoDQ is the list of ALPN tokens accepted on DoQ connections.  The
// first one, NextProtoDQ, is the token of the final RFC 9250.  The rest are
// the tokens of earlier drafts, kept for compatibility with older clients.
var compatProtoDQ = []string{NextProtoDQ, "doq-i02", "doq-i00", "dq"}

// maxQUICIdleTimeout is maximum QUIC idle timeout.  The default value in
// quic-go is 30 seconds, but our internal tests show that a higher value works
// better for clients written with ngtcp2.
const maxQUICIdleTimeout = 5 * time.Minute

// quicAddrValidatorCacheSize is the size of the cache that we use in the QUIC
// address validator.  The value is chosen arbitrarily and we should consider
// making it configurable.
// TODO(ameshkov): make it configurable.
const quicAddrValidatorCacheSize = 1000

// quicAddrValidatorCacheTTL is time-to-live for cache items in the QUIC address
// validator.  The value is chosen arbitrarily and we should consider making it
// configurable.
// TODO(ameshkov): make it configurable.
const quicAddrValidatorCacheTTL = 30 * time.Minute

// DoQ application error codes, see RFC 9250, Section 4.3.
const (
	// DoQCodeNoError is used when the connection or stream needs to be closed,
	// but there is no error to signal.
	DoQCodeNoError quic.ApplicationErrorCode = 0x0

	// DoQCodeInternalError signals that the DoQ implementation encountered
	// an internal error and is incapable of pursuing the transaction or the
	// connection.
	DoQCodeInternalError quic.ApplicationErrorCode = 0x1

	// DoQCodeProtocolError signals that the DoQ implementation encountered
	// a protocol error and is forcibly aborting the connection.
	DoQCodeProtocolError quic.ApplicationErrorCode = 0x2
)

// createQUICListeners opens a QUIC listener for every address in
// QUICListenAddr.  The listeners accept all the ALPN tokens in compatProtoDQ.
func (p *Proxy) createQUICListeners() error {
	for _, addr := range p.QUICListenAddr {
		log.Info("Creating a QUIC listener")

		conf := p.TLSConfig.Clone()
		conf.NextProtos = compatProtoDQ

		ln, err := quic.ListenAddrEarly(addr.String(), conf, newServerQUICConfig())
		if err != nil {
			return fmt.Errorf("quic listener: %w", err)
		}

		p.quicListen = append(p.quicListen, ln)
		log.Info("Listening to quic://%s", ln.Addr())
	}

	return nil
}

// quicPacketLoop accepts incoming QUIC connections on l and handles each of
// them in a separate goroutine until accepting fails.
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, requestGoroutinesSema semaphore) {
	log.Info("Entering the DNS-over-QUIC listener loop on %s", l.Addr())

	for {
		conn, err := l.Accept(context.Background())
		if err != nil {
			if isQUICErrorForDebugLog(err) {
				log.Debug("accepting quic conn: closed or timed out: %s", err)
			} else {
				log.Error("accepting quic conn: %s", err)
			}

			return
		}

		// NOTE(review): every connection holds one slot of
		// requestGoroutinesSema for its whole lifetime, while
		// handleQUICConnection acquires another slot from the same semaphore
		// for every stream.  With a limited semaphore, that many open
		// connections leave no slots for their streams.  Confirm this is
		// intended.
		requestGoroutinesSema.acquire()
		go func() {
			defer requestGoroutinesSema.release()

			p.handleQUICConnection(conn, requestGoroutinesSema)
		}()
	}
}

// handleQUICConnection accepts streams on conn and handles each of them in a
// separate goroutine with handleQUICStream.  When accepting a stream fails,
// the connection is closed with DoQCodeNoError and the function returns.
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) handleQUICConnection(conn quic.Connection, requestGoroutinesSema semaphore) {
	for {
		// The stub to resolver DNS traffic follows a simple pattern in which
		// the client sends a query, and the server provides a response.  This
		// design specifies that for each subsequent query on a QUIC connection
		// the client MUST select the next available client-initiated
		// bidirectional stream.
		stream, err := conn.AcceptStream(context.Background())
		if err != nil {
			if isQUICErrorForDebugLog(err) {
				log.Debug("accepting quic stream: closed or timed out: %s", err)
			} else {
				log.Error("accepting quic stream: %s", err)
			}

			// Close the connection to make sure resources are freed.
			closeQUICConn(conn, DoQCodeNoError)

			return
		}

		requestGoroutinesSema.acquire()
		go func() {
			// Deferred, so it runs after the stream is closed below.
			defer requestGoroutinesSema.release()

			p.handleQUICStream(stream, conn)

			// The server MUST send the response(s) on the same stream and MUST
			// indicate, after the last response, through the STREAM FIN
			// mechanism that no further data will be sent on that stream.
			_ = stream.Close()
		}()
	}
}

// handleQUICStream reads a single DNS query from stream, processes it, and
// writes back the response.  The query is read into a buffer taken from
// p.bytesPool.  Only the first n bytes of that buffer belong to this stream;
// the rest may still hold bytes from an earlier use of the pooled buffer, so
// nothing past n is ever parsed.
func (p *Proxy) handleQUICStream(stream quic.Stream, conn quic.Connection) {
	bufPtr := p.bytesPool.Get().(*[]byte)
	defer p.bytesPool.Put(bufPtr)

	// One query - one stream.
	// The client MUST select the next available client-initiated bidirectional
	// stream for each subsequent query on a QUIC connection.
	//
	// readAll reads until the client's STREAM FIN, which it reports as a
	// successful read, so the number of received bytes is what tells whether
	// a complete query has arrived.
	buf := *bufPtr
	n, err := readAll(stream, buf)

	// Note that io.EOF does not really mean that there's any error, this is
	// just a signal that there will be no data to read anymore from this
	// stream.
	if (err != nil && err != io.EOF) || n < minDNSPacketSize {
		logShortQUICRead(err)

		return
	}

	// Restrict all parsing to the bytes actually received on this stream.
	// Otherwise a truncated query could be "completed" with stale bytes left
	// in the pooled buffer.
	data := buf[:n]

	// In theory, we should use ALPN to get the DoQ version properly.  However,
	// since there are not too many versions now, we only check how the DNS
	// query is encoded.  If it's sent with a 2-byte prefix, we consider this a
	// DoQ v1.  Otherwise, a draft version.
	doqVersion := DoQv1
	req := &dns.Msg{}

	// Note that we support both the old drafts and the new RFC.  In the old
	// draft DNS messages were not prefixed with the message length.  The
	// comparison is done on ints so that a large n can't wrap around and
	// accidentally match the prefix.
	packetLen := binary.BigEndian.Uint16(data[:2])
	if int(packetLen) == n-2 {
		err = req.Unpack(data[2:])
	} else {
		err = req.Unpack(data)
		doqVersion = DoQv1Draft
	}

	if err != nil {
		log.Error("unpacking quic packet: %s", err)
		closeQUICConn(conn, DoQCodeProtocolError)

		return
	}

	if !validQUICMsg(req) {
		// If a peer encounters such an error condition, it is considered a
		// fatal error.  It SHOULD forcibly abort the connection using QUIC's
		// CONNECTION_CLOSE mechanism and SHOULD use the DoQ error code
		// DOQ_PROTOCOL_ERROR.
		closeQUICConn(conn, DoQCodeProtocolError)

		return
	}

	d := p.newDNSContext(ProtoQUIC, req)
	d.Addr = conn.RemoteAddr()
	d.QUICStream = stream
	d.QUICConnection = conn
	d.DoQVersion = doqVersion

	err = p.handleDNSRequest(d)
	if err != nil {
		log.Tracef("error handling DNS (%s) request: %s", d.Proto, err)
	}
}

// respondQUIC writes d.Res to d.QUICStream, encoding it according to
// d.DoQVersion.  If there is no response, the whole QUIC connection is closed
// with DoQCodeInternalError and an error is returned.
func (p *Proxy) respondQUIC(d *DNSContext) error {
	if d.Res == nil {
		// Nothing to send, so tear the connection down instead.
		closeQUICConn(d.QUICConnection, DoQCodeInternalError)

		return errors.Error("no response to write")
	}

	packed, err := d.Res.Pack()
	if err != nil {
		return fmt.Errorf("couldn't convert message into wire format: %w", err)
	}

	// DoQ v1 messages carry a 2-byte length prefix, while the old draft
	// versions sent the bare message.
	var out []byte
	if d.DoQVersion == DoQv1 {
		out = proxyutil.AddPrefix(packed)
	} else if d.DoQVersion == DoQv1Draft {
		out = packed
	} else {
		return fmt.Errorf("invalid protocol version: %d", d.DoQVersion)
	}

	written, err := d.QUICStream.Write(out)
	switch {
	case err != nil:
		return fmt.Errorf("conn.Write(): %w", err)
	case written != len(out):
		return fmt.Errorf("conn.Write() returned with %d != %d", written, len(out))
	default:
		return nil
	}
}

// validQUICMsg reports whether req is acceptable under the DoQ protocol error
// rules.  Of those rules, only the edns-tcp-keepalive one is enforced here.
func validQUICMsg(req *dns.Msg) (ok bool) {
	// See https://www.rfc-editor.org/rfc/rfc9250.html#name-protocol-errors
	//
	// Rules that are deliberately NOT checked here:
	//
	// 1. a client or server receives a message with a non-zero Message ID.
	//    Some stub proxies send non-zero Message IDs, so this is tolerated.
	//
	// 2. a client or server receives a STREAM FIN before receiving all the
	//    bytes for a message indicated in the 2-octet length field.
	// 3. a server receives more than one query on a stream.
	//    Both are covered earlier, when the DNS message is unpacked.
	//
	// 4. the client or server does not indicate the expected STREAM FIN after
	//    sending requests or responses (see Section 4.2).
	//    Validating this would mean waiting for STREAM FIN before processing
	//    the message, so it is consciously ignored.
	//
	// 6. a client or a server attempts to open a unidirectional QUIC stream.
	//    This can only be handled when writing a response.
	//
	// 7. a server receives a "replayable" transaction in 0-RTT data.
	//    quic-go doesn't expose the information needed to validate this.

	// 5. an implementation receives a message containing the
	// edns-tcp-keepalive EDNS(0) Option [RFC7828] (see Section 5.5.2).
	opt := req.IsEdns0()
	if opt == nil {
		return true
	}

	for _, o := range opt.Option {
		if o.Option() != dns.EDNS0TCPKEEPALIVE {
			continue
		}

		log.Debug("client sent EDNS0 TCP keepalive option")

		return false
	}

	return true
}

// logShortQUICRead logs why a read from a QUIC stream didn't yield a usable
// DNS query.  A nil err means the stream ended cleanly but carried too few
// bytes.  Otherwise the log level depends on isQUICErrorForDebugLog.
func logShortQUICRead(err error) {
	switch {
	case err == nil:
		log.Info("quic packet too short for dns query")
	case isQUICErrorForDebugLog(err):
		log.Debug("reading from quic stream: closed or timeout: %s", err)
	default:
		log.Error("reading from quic stream: %s", err)
	}
}

// QUIC application error codes that isQUICErrorForDebugLog treats as
// non-critical.
const (
	// qCodeNoError is returned when the QUIC connection was gracefully
	// closed and there is no error to signal.
	qCodeNoError quic.ApplicationErrorCode = quic.ApplicationErrorCode(quic.NoError)

	// qCodeApplicationErrorError is used for Initial and Handshake packets.
	// It is considered non-critical and is not logged as an error.
	qCodeApplicationErrorError quic.ApplicationErrorCode = quic.ApplicationErrorCode(
		quic.ApplicationErrorErrorCode,
	)
)

// isQUICErrorForDebugLog reports whether err is an expected, non-critical
// QUIC error that only deserves a debug-level log message.  Such errors are
// most probably related to the current QUIC implementation.  err must not be
// nil.
//
// TODO(ameshkov): re-test when updating quic-go.
func isQUICErrorForDebugLog(err error) (ok bool) {
	// quic.ErrServerClosed: the QUIC listener was closed by us.
	// quic.Err0RTTRejected: the server rejected 0-RTT on AcceptStream, which
	// is a common scenario.
	if errors.Is(err, quic.ErrServerClosed) || errors.Is(err, quic.Err0RTTRejected) {
		return true
	}

	var appErr *quic.ApplicationError
	if errors.As(err, &appErr) {
		// TODO(a.garipov): Consider adding other error codes.
		switch appErr.ErrorCode {
		case qCodeNoError, qCodeApplicationErrorError:
			return true
		}
	}

	// An idle timeout happens when accepting a new stream from a connection
	// that had no activity for longer than the keep-alive timeout, which is
	// also common.
	var idleErr *quic.IdleTimeoutError

	return errors.As(err, &idleErr)
}

// closeQUICConn closes conn with the application error code and an empty
// reason.  A failure to close is only logged at debug level.
func closeQUICConn(conn quic.Connection, code quic.ApplicationErrorCode) {
	log.Debug("closing quic conn %s with code %d", conn.LocalAddr(), code)

	if closeErr := conn.CloseWithError(code, ""); closeErr != nil {
		log.Debug("closing quic connection with code %d: %s", code, closeErr)
	}
}

// newServerQUICConfig returns the *quic.Config shared by the DoQ and DoH3
// servers.  Address validation is delegated to a fresh quicAddrValidator.
func newServerQUICConfig() (conf *quic.Config) {
	validator := newQUICAddrValidator(quicAddrValidatorCacheSize, quicAddrValidatorCacheTTL)

	conf = &quic.Config{
		// Enable 0-RTT by default for all connections on the server-side.
		Allow0RTT:                true,
		MaxIdleTimeout:           maxQUICIdleTimeout,
		MaxIncomingStreams:       math.MaxUint16,
		MaxIncomingUniStreams:    math.MaxUint16,
		RequireAddressValidation: validator.requiresValidation,
	}

	return conf
}

// quicAddrValidator decides whether a QUIC client has to go through address
// validation.  It remembers recently seen client IP addresses in an LRU cache
// so that they are not asked to validate again until their entry expires or
// is evicted.
type quicAddrValidator struct {
	// cache holds the string forms of client IP addresses for which
	// validation is currently not required.
	cache gcache.Cache

	// ttl is the lifetime of every cache entry.
	ttl time.Duration
}

// newQUICAddrValidator returns a validator whose LRU cache holds up to
// cacheSize addresses, each kept for ttl.
func newQUICAddrValidator(cacheSize int, ttl time.Duration) (v *quicAddrValidator) {
	lru := gcache.New(cacheSize).LRU().Build()
	v = &quicAddrValidator{cache: lru, ttl: ttl}

	return v
}

// requiresValidation reports whether the server should send a QUIC Retry
// packet to the client at addr.  A Retry lets the server verify the client's
// address but increases latency, so each client IP is asked only once:
//   - the first call for an IP remembers it for v.ttl and returns true;
//   - later calls return false while the IP is still cached.
func (v *quicAddrValidator) requiresValidation(addr net.Addr) (ok bool) {
	// addr is expected to be a *net.UDPAddr; any other type panics, which is
	// acceptable here.
	ip := addr.(*net.UDPAddr).IP.String()
	if v.cache.Has(ip) {
		return false
	}

	if err := v.cache.SetWithExpire(ip, true, v.ttl); err != nil {
		// Shouldn't happen, since no serialization function is configured.
		panic(fmt.Errorf("quic validator: setting cache item: %w", err))
	}

	return true
}

// readAll reads from r until an error or io.EOF into the specified buffer buf.
// A successful call returns err == nil, not err == io.EOF.  If the buffer is
// too small, it returns error io.ErrShortBuffer.  This function has some
// similarities to io.ReadAll, but it reads to the specified buffer and not
// allocates (and grows) a new one.  Also, it is completely different from
// io.ReadFull as that one reads the exact number of bytes (buffer length) and
// readAll reads until io.EOF or until the buffer is filled.
func readAll(r io.Reader, buf []byte) (n int, err error) {
	for {
		if n == len(buf) {
			return n, io.ErrShortBuffer
		}

		var read int
		read, err = r.Read(buf[n:])
		n += read

		if err != nil {
			if err == io.EOF {
				err = nil
			}
			return n, err
		}
	}
}
0707010000005F000081A4000000000000000000000001650C5921000013AD000000000000000000000000000000000000002A00000000dnsproxy-0.55.0/proxy/server_quic_test.gopackage proxy

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"io"
	"net"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/stretchr/testify/require"
)

func TestQuicProxy(t *testing.T) {
	serverConfig, caPem := createServerTLSConfig(t)
	dnsProxy := createTestProxy(t, serverConfig)
	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	require.NoError(t, dnsProxy.Start())

	roots := x509.NewCertPool()
	roots.AppendCertsFromPEM(caPem)

	// Dial the proxy's DNS-over-QUIC address, offering both the RFC and the
	// draft ALPN identifiers.
	conn, err := quic.DialAddrEarly(
		context.Background(),
		dnsProxy.Addr(ProtoQUIC).String(),
		&tls.Config{
			ServerName: tlsServerName,
			RootCAs:    roots,
			NextProtos: append([]string{NextProtoDQ}, compatProtoDQ...),
		},
		nil,
	)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) {
		return conn.CloseWithError(DoQCodeNoError, "")
	})

	// Alternate RFC-encoded and draft-encoded queries over one connection.
	for round := 0; round < 10; round++ {
		sendTestQUICMessage(t, conn, DoQv1)
		sendTestQUICMessage(t, conn, DoQv1Draft)
	}
}

func TestQuicProxy_largePackets(t *testing.T) {
	serverConfig, caPem := createServerTLSConfig(t)
	dnsProxy := createTestProxy(t, serverConfig)

	// Answer every query locally so that no real upstream is contacted.
	dnsProxy.RequestHandler = func(_ *Proxy, d *DNSContext) (err error) {
		q := d.Req.Question[0]

		resp := (&dns.Msg{}).SetReply(d.Req)
		resp.Answer = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET},
			A:   net.IP{8, 8, 8, 8},
		}}
		d.Res = resp

		return nil
	}

	testutil.CleanupAndRequireSuccess(t, dnsProxy.Stop)

	require.NoError(t, dnsProxy.Start())

	roots := x509.NewCertPool()
	roots.AppendCertsFromPEM(caPem)

	conn, err := quic.DialAddrEarly(
		context.Background(),
		dnsProxy.Addr(ProtoQUIC).String(),
		&tls.Config{
			ServerName: tlsServerName,
			RootCAs:    roots,
			NextProtos: append([]string{NextProtoDQ}, compatProtoDQ...),
		},
		nil,
	)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, func() (err error) {
		return conn.CloseWithError(DoQCodeNoError, "")
	})

	// Pad the query so that it can't fit into a single QUIC frame.
	padding := &dns.EDNS0_PADDING{Padding: make([]byte, 4096)}
	msg := createTestMessage()
	msg.Extra = []dns.RR{&dns.OPT{
		Hdr:    dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT, Class: 4096},
		Option: []dns.EDNS0{padding},
	}}

	requireResponse(t, msg, sendQUICMessage(t, msg, conn, DoQv1))
}

// sendQUICMessage sends msg over a new stream of conn, encoded according to
// doqVersion, and returns the unpacked response.  The response is read until
// the server closes its side of the stream.  A single Read may return only a
// part of a response spanning several QUIC frames, so one Read is not enough.
func sendQUICMessage(
	t *testing.T,
	msg *dns.Msg,
	conn quic.Connection,
	doqVersion DoQVersion,
) (resp *dns.Msg) {
	t.Helper()

	// Open a new QUIC stream to write there a test DNS query.
	stream, err := conn.OpenStreamSync(context.Background())
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, stream.Close)

	packedMsg, err := msg.Pack()
	require.NoError(t, err)

	buf := packedMsg
	if doqVersion == DoQv1 {
		buf = proxyutil.AddPrefix(packedMsg)
	}

	// Send the DNS query to the stream.
	err = writeQUICStream(buf, stream)
	require.NoError(t, err)

	// Close closes the write-direction of the stream and sends
	// a STREAM FIN packet.
	_ = stream.Close()

	// Read the whole response: io.ReadAll stops at the server's STREAM FIN
	// (io.EOF), which it reports as a nil error.
	respBytes, err := io.ReadAll(stream)
	require.NoError(t, err)
	require.Greater(t, len(respBytes), minDNSPacketSize)

	// Unpack only the bytes that were actually received.
	resp = new(dns.Msg)
	if doqVersion == DoQv1 {
		err = resp.Unpack(respBytes[2:])
	} else {
		err = resp.Unpack(respBytes)
	}
	require.NoError(t, err)

	return resp
}

// writeQUICStream writes buf to stream in chunks of at most 400 bytes, which
// makes it possible to test how the server handles DNS messages that arrive in
// pieces.  When buf needs more than one chunk, every write is followed by a
// short sleep to emulate network latency.
func writeQUICStream(buf []byte, stream quic.Stream) (err error) {
	// 400 is an arbitrarily chosen chunk size.
	const chunkSize = 400

	multiChunk := len(buf) > chunkSize
	for rest := buf; len(rest) > 0; {
		chunk := rest
		if len(chunk) > chunkSize {
			chunk = chunk[:chunkSize]
		}
		rest = rest[len(chunk):]

		if _, err = stream.Write(chunk); err != nil {
			return err
		}

		if multiChunk {
			time.Sleep(time.Millisecond)
		}
	}

	return nil
}

// sendTestQUICMessage sends the standard test query over conn using
// doqVersion and checks that a matching response comes back.
func sendTestQUICMessage(t *testing.T, conn quic.Connection, doqVersion DoQVersion) {
	req := createTestMessage()
	requireResponse(t, req, sendQUICMessage(t, req, conn, doqVersion))
}
07070100000060000081A4000000000000000000000001650C59210000142D000000000000000000000000000000000000002400000000dnsproxy-0.55.0/proxy/server_tcp.gopackage proxy

import (
	"context"
	"crypto/tls"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"time"

	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
)

// createTCPListeners opens a TCP listener for every address in
// p.TCPListenAddr and appends it to p.tcpListen.  It stops at the first
// failure.
func (p *Proxy) createTCPListeners(ctx context.Context) (err error) {
	for _, a := range p.TCPListenAddr {
		log.Info("dnsproxy: creating tcp server socket %s", a)

		lsnr, lErr := proxynetutil.ListenConfig().Listen(ctx, "tcp", a.String())
		if lErr != nil {
			return fmt.Errorf("listening to tcp socket: %w", lErr)
		}

		tcpListener, ok := lsnr.(*net.TCPListener)
		if !ok {
			// The socket is already open but can't be used, so close it
			// instead of leaking it.
			_ = lsnr.Close()

			return fmt.Errorf("wrong listener type on tcp addr %s: %T", a, lsnr)
		}

		p.tcpListen = append(p.tcpListen, tcpListener)

		log.Info("dnsproxy: listening to tcp://%s", tcpListener.Addr())
	}

	return nil
}

// createTLSListeners opens a TCP listener for every address in
// p.TLSListenAddr, wraps it with p.TLSConfig, and appends the result to
// p.tlsListen.  It stops at the first failure.
func (p *Proxy) createTLSListeners() (err error) {
	for _, addr := range p.TLSListenAddr {
		log.Info("dnsproxy: creating tls server socket %s", addr)

		tcpListen, lErr := net.ListenTCP("tcp", addr)
		if lErr != nil {
			return fmt.Errorf("listening on tls addr %s: %w", addr, lErr)
		}

		tlsListener := tls.NewListener(tcpListen, p.TLSConfig)
		p.tlsListen = append(p.tlsListen, tlsListener)

		log.Info("dnsproxy: listening to tls://%s", tlsListener.Addr())
	}

	return nil
}

// tcpPacketLoop accepts connections from l until it fails, serving each one
// in its own goroutine.  proto must be either "tcp" or "tls".  A listener
// closed with net.ErrClosed is logged at debug level, any other accept error
// at error level.
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, requestGoroutinesSema semaphore) {
	log.Info("dnsproxy: entering %s listener loop on %s", proto, l.Addr())

	for {
		clientConn, err := l.Accept()
		if err != nil {
			if !errors.Is(err, net.ErrClosed) {
				log.Error("dnsproxy: reading from tcp: %s", err)
			} else {
				log.Debug("dnsproxy: tcp connection %s closed", l.Addr())
			}

			return
		}

		requestGoroutinesSema.acquire()
		go func(c net.Conn) {
			p.handleTCPConnection(c, proto)
			requestGoroutinesSema.release()
		}(clientConn)
	}
}

// handleTCPConnection reads length-prefixed DNS queries from conn and handles
// them one after another until reading fails, a query can't be unpacked, or
// the proxy is stopped.  conn is always closed on return.  proto must be
// either ProtoTCP or ProtoTLS.
func (p *Proxy) handleTCPConnection(conn net.Conn, proto Proto) {
	defer log.OnPanic("proxy.handleTCPConnection")

	log.Debug("dnsproxy: handling new %s request from %s", proto, conn.RemoteAddr())

	defer func() {
		err := conn.Close()
		if err != nil {
			logWithNonCrit(err, "dnsproxy: handling tcp: closing conn")
		}
	}()

	for {
		// Copy the flag under the read lock and release the lock before
		// acting on it.  Returning while still holding the read lock would
		// leave the mutex read-locked forever and deadlock every later
		// p.Lock call.
		p.RLock()
		started := p.started
		p.RUnlock()

		if !started {
			return
		}

		err := conn.SetDeadline(time.Now().Add(defaultTimeout))
		if err != nil {
			// Consider deadline errors non-critical.
			logWithNonCrit(err, "handling tcp: setting deadline")
		}

		packet, err := readPrefixed(conn)
		if err != nil {
			logWithNonCrit(err, "handling tcp: reading msg")

			break
		}

		req := &dns.Msg{}
		err = req.Unpack(packet)
		if err != nil {
			log.Error("dnsproxy: handling tcp: unpacking msg: %s", err)

			return
		}

		d := p.newDNSContext(proto, req)
		d.Addr = conn.RemoteAddr()
		d.Conn = conn

		err = p.handleDNSRequest(d)
		if err != nil {
			logWithNonCrit(err, fmt.Sprintf("handling tcp: handling %s request", d.Proto))
		}
	}
}

// errTooLarge means that a DNS message is larger than 64KiB.
const errTooLarge errors.Error = "dns message is too large"

// readPrefixed reads one DNS message, preceded by a 2-byte big-endian length
// prefix, from conn.  Both the prefix and the message are read with
// io.ReadFull, because a single Read on a stream connection may return fewer
// bytes than requested.  A connection closed before any prefix byte arrives
// yields an error wrapping io.EOF.
func readPrefixed(conn net.Conn) (b []byte, err error) {
	l := make([]byte, 2)
	_, err = io.ReadFull(conn, l)
	if err != nil {
		return nil, fmt.Errorf("reading len: %w", err)
	}

	packetLen := binary.BigEndian.Uint16(l)
	// Defensive check: with a uint16 length it only matters if
	// dns.MaxMsgSize is ever lower than the uint16 range.
	if packetLen > dns.MaxMsgSize {
		return nil, errTooLarge
	}

	b = make([]byte, packetLen)
	_, err = io.ReadFull(conn, b)
	if err != nil {
		return nil, fmt.Errorf("reading msg: %w", err)
	}

	return b, nil
}

// logWithNonCrit logs err prefixed with msg.  Errors meaning the connection
// was closed (io.EOF, net.ErrClosed, EPIPE) or timed out go to the debug
// level; everything else is logged as an error.
func logWithNonCrit(err error, msg string) {
	if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || isEPIPE(err) {
		log.Debug("%s: connection is closed; original error: %s", msg, err)

		return
	}

	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		log.Debug("%s: connection timed out; original error: %s", msg, err)

		return
	}

	log.Error("%s: %s", msg, err)
}

// respondTCP writes d.Res to the TCP (or TLS) client as a length-prefixed
// message.  With no response, the client connection is closed right away.
// Writes to an already closed connection are not reported as errors.
func (p *Proxy) respondTCP(d *DNSContext) error {
	if d.Res == nil {
		return d.Conn.Close()
	}

	packed, err := d.Res.Pack()
	if err != nil {
		return fmt.Errorf("packing message: %w", err)
	}

	if err = writePrefixed(packed, d.Conn); err != nil && !errors.Is(err, net.ErrClosed) {
		return fmt.Errorf("writing message: %w", err)
	}

	return nil
}

// writePrefixed writes a DNS message to a TCP connection it first writes
// a 2-byte prefix followed by the message itself.
func writePrefixed(b []byte, conn net.Conn) (err error) {
	l := make([]byte, 2)
	binary.BigEndian.PutUint16(l, uint16(len(b)))
	_, err = (&net.Buffers{l, b}).WriteTo(conn)

	return err
}
07070100000061000081A4000000000000000000000001650C592100000581000000000000000000000000000000000000002900000000dnsproxy-0.55.0/proxy/server_tcp_test.gopackage proxy

import (
	"crypto/tls"
	"crypto/x509"
	"testing"

	"github.com/miekg/dns"
)

func TestTcpProxy(t *testing.T) {
	// Prepare the proxy server.
	dnsProxy := createTestProxy(t, nil)

	// Start listening.
	err := dnsProxy.Start()
	if err != nil {
		t.Fatalf("cannot start the DNS proxy: %s", err)
	}

	// Stop the proxy even when the test fails midway.  Cleanups run in
	// reverse order, so the client connection below is closed first.
	t.Cleanup(func() {
		if stopErr := dnsProxy.Stop(); stopErr != nil {
			t.Errorf("cannot stop the DNS proxy: %s", stopErr)
		}
	})

	// Create a DNS-over-TCP client connection.
	addr := dnsProxy.Addr(ProtoTCP)
	conn, err := dns.Dial("tcp", addr.String())
	if err != nil {
		t.Fatalf("cannot connect to the proxy: %s", err)
	}
	t.Cleanup(func() {
		if closeErr := conn.Close(); closeErr != nil {
			t.Errorf("cannot close the client connection: %s", closeErr)
		}
	})

	sendTestMessages(t, conn)
}

func TestTlsProxy(t *testing.T) {
	// Prepare the proxy server.
	serverConfig, caPem := createServerTLSConfig(t)
	dnsProxy := createTestProxy(t, serverConfig)

	// Start listening.
	err := dnsProxy.Start()
	if err != nil {
		t.Fatalf("cannot start the DNS proxy: %s", err)
	}

	// Stop the proxy even when the test fails midway.  Cleanups run in
	// reverse order, so the client connection below is closed first.
	t.Cleanup(func() {
		if stopErr := dnsProxy.Stop(); stopErr != nil {
			t.Errorf("cannot stop the DNS proxy: %s", stopErr)
		}
	})

	roots := x509.NewCertPool()
	roots.AppendCertsFromPEM(caPem)
	tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots}

	// Create a DNS-over-TLS client connection.
	addr := dnsProxy.Addr(ProtoTLS)
	conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig)
	if err != nil {
		t.Fatalf("cannot connect to the proxy: %s", err)
	}
	t.Cleanup(func() {
		if closeErr := conn.Close(); closeErr != nil {
			t.Errorf("cannot close the client connection: %s", closeErr)
		}
	})

	sendTestMessages(t, conn)
}
07070100000062000081A4000000000000000000000001650C592100000EC3000000000000000000000000000000000000002400000000dnsproxy-0.55.0/proxy/server_udp.gopackage proxy

import (
	"context"
	"fmt"
	"net"

	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
)

// createUDPListeners opens a UDP socket for every address in p.UDPListenAddr
// and appends it to p.udpListen.  It stops at the first failure.
func (p *Proxy) createUDPListeners(ctx context.Context) (err error) {
	for _, addr := range p.UDPListenAddr {
		conn, cErr := p.udpCreate(ctx, addr)
		if cErr != nil {
			return fmt.Errorf("listening on udp addr %s: %w", addr, cErr)
		}

		p.udpListen = append(p.udpListen, conn)
	}

	return nil
}

// udpCreate opens a UDP listening socket on udpAddr.  It also applies
// p.Config.UDPBufferSize (when positive) and the package's UDP socket options.
// The socket is closed again if any of these steps fails.
func (p *Proxy) udpCreate(ctx context.Context, udpAddr *net.UDPAddr) (*net.UDPConn, error) {
	log.Info("dnsproxy: creating udp server socket %s", udpAddr)

	pc, err := proxynetutil.ListenConfig().ListenPacket(ctx, "udp", udpAddr.String())
	if err != nil {
		return nil, fmt.Errorf("listening to udp socket: %w", err)
	}

	conn := pc.(*net.UDPConn)

	if size := p.Config.UDPBufferSize; size > 0 {
		if err = conn.SetReadBuffer(size); err != nil {
			_ = conn.Close()

			return nil, fmt.Errorf("setting udp buf size: %w", err)
		}
	}

	if err = proxynetutil.UDPSetOptions(conn); err != nil {
		_ = conn.Close()

		return nil, fmt.Errorf("setting udp opts: %w", err)
	}

	log.Info("dnsproxy: listening to udp://%s", conn.LocalAddr())

	return conn, nil
}

// udpPacketLoop reads UDP packets from conn and handles each non-empty one in
// its own goroutine.  It returns when the proxy is stopped or reading fails.
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema semaphore) {
	log.Info("dnsproxy: entering udp listener loop on %s", conn.LocalAddr())

	b := make([]byte, dns.MaxMsgSize)
	for {
		// Copy the flag under the read lock and release the lock before
		// acting on it.  Returning while still holding the read lock would
		// leave the mutex read-locked forever and deadlock every later
		// p.Lock call.
		p.RLock()
		started := p.started
		p.RUnlock()

		if !started {
			return
		}

		n, localIP, remoteAddr, err := proxynetutil.UDPRead(conn, b, p.udpOOBSize)
		// documentation says to handle the packet even if err occurs, so do that first
		if n > 0 {
			// make a copy of all bytes because ReadFrom() will overwrite contents of b on next call
			// we need the contents to survive the call because we're handling them in goroutine
			packet := make([]byte, n)
			copy(packet, b)
			requestGoroutinesSema.acquire()
			go func() {
				p.udpHandlePacket(packet, localIP, remoteAddr, conn)
				requestGoroutinesSema.release()
			}()
		}
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				log.Debug("dnsproxy: udp connection %s closed", conn.LocalAddr())
			} else {
				log.Error("dnsproxy: reading from udp: %s", err)
			}

			break
		}
	}
}

// udpHandlePacket unpacks packet as a DNS query received on conn from
// remoteAddr at localIP, and passes it to handleDNSRequest.  Packets that
// can't be unpacked are logged and dropped.
func (p *Proxy) udpHandlePacket(packet []byte, localIP net.IP, remoteAddr *net.UDPAddr, conn *net.UDPConn) {
	log.Debug("dnsproxy: handling new udp packet from %s", remoteAddr)

	req := new(dns.Msg)
	if err := req.Unpack(packet); err != nil {
		log.Error("dnsproxy: unpacking udp packet: %s", err)

		return
	}

	d := p.newDNSContext(ProtoUDP, req)
	d.Addr, d.Conn, d.localIP = remoteAddr, conn, localIP

	if err := p.handleDNSRequest(d); err != nil {
		log.Debug("dnsproxy: handling dns (proto %s) request: %s", d.Proto, err)
	}
}

// respondUDP sends d.Res back to the UDP client at d.Addr, using d.localIP as
// the source address.  A nil response, or a write to an already closed
// socket, is silently ignored.
func (p *Proxy) respondUDP(d *DNSContext) error {
	if d.Res == nil {
		return nil
	}

	packed, err := d.Res.Pack()
	if err != nil {
		return fmt.Errorf("packing message: %w", err)
	}

	written, err := proxynetutil.UDPWrite(packed, d.Conn.(*net.UDPConn), d.Addr.(*net.UDPAddr), d.localIP)
	switch {
	case errors.Is(err, net.ErrClosed):
		return nil
	case err != nil:
		return fmt.Errorf("writing message: %w", err)
	case written != len(packed):
		return fmt.Errorf("udpWrite() returned with %d != %d", written, len(packed))
	default:
		return nil
	}
}
07070100000063000081A4000000000000000000000001650C592100000267000000000000000000000000000000000000002900000000dnsproxy-0.55.0/proxy/server_udp_test.gopackage proxy

import (
	"testing"

	"github.com/miekg/dns"
)

func TestUdpProxy(t *testing.T) {
	// Prepare the proxy server.
	dnsProxy := createTestProxy(t, nil)

	// Start listening.
	err := dnsProxy.Start()
	if err != nil {
		t.Fatalf("cannot start the DNS proxy: %s", err)
	}

	// Stop the proxy even when the test fails midway.  Cleanups run in
	// reverse order, so the client connection below is closed first.
	t.Cleanup(func() {
		if stopErr := dnsProxy.Stop(); stopErr != nil {
			t.Errorf("cannot stop the DNS proxy: %s", stopErr)
		}
	})

	// Create a DNS-over-UDP client connection.
	addr := dnsProxy.Addr(ProtoUDP)
	conn, err := dns.Dial("udp", addr.String())
	if err != nil {
		t.Fatalf("cannot connect to the proxy: %s", err)
	}
	t.Cleanup(func() {
		if closeErr := conn.Close(); closeErr != nil {
			t.Errorf("cannot close the client connection: %s", closeErr)
		}
	})

	sendTestMessages(t, conn)
}
07070100000064000081A4000000000000000000000001650C59210000257F000000000000000000000000000000000000002300000000dnsproxy-0.55.0/proxy/upstreams.gopackage proxy

import (
	"fmt"
	"io"
	"strings"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/stringutil"
	"golang.org/x/exp/slices"
)

// UpstreamConfig groups the default upstreams with the domain-specific ones.
// It is normally built by ParseUpstreamsConfig.
type UpstreamConfig struct {
	// DomainReservedUpstreams maps reserved domains to the upstreams that
	// serve them.
	DomainReservedUpstreams map[string][]upstream.Upstream

	// SpecifiedDomainUpstreams maps explicitly specified (non-wildcard)
	// domains to their upstreams.
	SpecifiedDomainUpstreams map[string][]upstream.Upstream

	// SubdomainExclusions is the set of domains that were specified in the
	// wildcard "*." form.
	SubdomainExclusions *stringutil.Set

	// Upstreams is the list of default upstreams.
	Upstreams []upstream.Upstream
}

// Compile-time check that *UpstreamConfig implements io.Closer.
var _ io.Closer = (*UpstreamConfig)(nil)

// ParseUpstreamsConfig parses upstreamConfig, a list of upstream lines, into
// an *UpstreamConfig.  It returns an error if any line is invalid.  Supported
// syntax:
//
//   - default upstream: <upstreamString>
//   - reserved upstream: [/domain1/../domainN/]<upstreamString>
//   - subdomains-only upstream: [/*.domain1/../*.domainN/]<upstreamString>
//
// More specific domains take priority over less specific ones.  To exclude a
// more specific domain from reserved upstreams, use "#" as the upstream:
// [/domain1/../domainN/]#.  For example, the configuration
//
//	["[/host.com/]1.2.3.4", "[/www.host.com/]2.3.4.5", "[/maps.host.com/]#", "3.4.5.6"]
//
// sends queries for *.host.com to 1.2.3.4, except for *.www.host.com, which
// go to 2.3.4.5, and *.maps.host.com, which go to the default upstream 3.4.5.6
// along with all other domains.
//
// To exclude the domain itself from a reserved upstream, use the wildcard
// form.  For example, ["[/*.domain.com/]1.2.3.4", "3.4.5.6"] sends queries for
// all subdomains of domain.com to 1.2.3.4, while domain.com itself goes to the
// default upstream 3.4.5.6 like every other query.
//
// Lines with the same upstream address share a single upstream instance.
// options may be nil, in which case empty options are used; every newly
// created upstream receives its own clone of options.  Errors from creating an
// upstream are wrapped, so callers can inspect them with errors.Is/As.
//
// TODO(e.burkov):  Refactor this mess.
func ParseUpstreamsConfig(upstreamConfig []string, options *upstream.Options) (*UpstreamConfig, error) {
	if options == nil {
		options = &upstream.Options{}
	}

	if len(options.Bootstrap) > 0 {
		log.Debug("Bootstraps: %v", options.Bootstrap)
	}

	var upstreams []upstream.Upstream
	// We use this index to avoid creating duplicates of upstreams.
	upstreamsIndex := map[string]upstream.Upstream{}

	domainReservedUpstreams := map[string][]upstream.Upstream{}
	specifiedDomainUpstreams := map[string][]upstream.Upstream{}
	subdomainsOnlyUpstreams := map[string][]upstream.Upstream{}
	subdomainsOnlyExclusions := stringutil.NewSet()

	for i, l := range upstreamConfig {
		u, hosts, err := parseUpstreamLine(l)
		if err != nil {
			return &UpstreamConfig{}, err
		}

		// # excludes more specific domain from reserved upstreams querying.
		if u == "#" && len(hosts) > 0 {
			for _, host := range hosts {
				if strings.HasPrefix(host, "*.") {
					host = host[len("*."):]

					subdomainsOnlyExclusions.Add(host)
					subdomainsOnlyUpstreams[host] = nil
				} else {
					domainReservedUpstreams[host] = nil
					specifiedDomainUpstreams[host] = nil
				}
			}
		} else {
			dnsUpstream, ok := upstreamsIndex[u]
			if !ok {
				// Create an upstream.
				dnsUpstream, err = upstream.AddressToUpstream(u, options.Clone())

				if err != nil {
					// Wrap with %w so that the original error stays
					// inspectable by callers.
					err = fmt.Errorf("cannot prepare the upstream %s (%s): %w", l, options.Bootstrap, err)

					return &UpstreamConfig{}, err
				}

				// Save to the index.
				upstreamsIndex[u] = dnsUpstream
			}

			if len(hosts) == 0 {
				log.Debug("Upstream %d: %s", i, dnsUpstream.Address())
				upstreams = append(upstreams, dnsUpstream)

				continue
			}

			for _, host := range hosts {
				if strings.HasPrefix(host, "*.") {
					host = host[len("*."):]

					subdomainsOnlyExclusions.Add(host)
					log.Debug("domain %s is added to exclusions list", host)

					subdomainsOnlyUpstreams[host] = append(subdomainsOnlyUpstreams[host], dnsUpstream)
				} else {
					specifiedDomainUpstreams[host] = append(specifiedDomainUpstreams[host], dnsUpstream)
				}

				domainReservedUpstreams[host] = append(domainReservedUpstreams[host], dnsUpstream)
			}

			log.Debug("Upstream %d: %s is reserved for next domains: %s",
				i, dnsUpstream.Address(), strings.Join(hosts, ", "))
		}
	}

	for host, ups := range subdomainsOnlyUpstreams {
		// Rewrite ups for wildcard subdomains to remove upper level domains specs.
		domainReservedUpstreams[host] = ups
	}

	return &UpstreamConfig{
		Upstreams:                upstreams,
		DomainReservedUpstreams:  domainReservedUpstreams,
		SpecifiedDomainUpstreams: specifiedDomainUpstreams,
		SubdomainExclusions:      subdomainsOnlyExclusions,
	}, nil
}

// errNoDefaultUpstreams is returned when no default upstreams are specified
// within a [Config.UpstreamConfig].
const errNoDefaultUpstreams errors.Error = "no default upstreams specified"

// validate returns an error if uc isn't configured properly.  uc is considered
// valid if it contains at least one default upstream.  A nil uc causes an
// error wrapping [errNoDefaultUpstreams].  A uc with only domain-specific
// upstreams causes [errNoDefaultUpstreams].  A uc without any upstreams at all
// causes [upstream.ErrNoUpstreams].
func (uc *UpstreamConfig) validate() (err error) {
	if uc == nil {
		return fmt.Errorf("%w; uc is nil", errNoDefaultUpstreams)
	}

	if len(uc.Upstreams) > 0 {
		return nil
	}

	if len(uc.DomainReservedUpstreams) == 0 && len(uc.SpecifiedDomainUpstreams) == 0 {
		return upstream.ErrNoUpstreams
	}

	return errNoDefaultUpstreams
}

// parseUpstreamLine - parses upstream line and returns the following:
// upstream address
// list of domains for which this upstream is reserved (may be nil)
// error if something went wrong
func parseUpstreamLine(l string) (string, []string, error) {
	var hosts []string
	u := l

	if strings.HasPrefix(l, "[/") {
		// split domains and upstream string
		domainsAndUpstream := strings.Split(strings.TrimPrefix(l, "[/"), "/]")
		if len(domainsAndUpstream) != 2 {
			return "", nil, fmt.Errorf("wrong upstream specification: %s", l)
		}

		// split domains list
		for _, confHost := range strings.Split(domainsAndUpstream[0], "/") {
			if confHost != "" {
				host := strings.TrimPrefix(confHost, "*.")
				if err := netutil.ValidateDomainName(host); err != nil {
					return "", nil, err
				}

				hosts = append(hosts, strings.ToLower(confHost+"."))
			} else {
				// empty domain specification means `unqualified names only`
				hosts = append(hosts, UnqualifiedNames)
			}
		}
		u = domainsAndUpstream[1]
	}

	return u, hosts, nil
}

// getUpstreamsForDomain looks for a domain in the reserved domains map and
// returns a list of corresponding upstreams.  It returns default upstreams list
// if the domain was not found in the map.  More specific domains take priority
// over less specific domains.  For example, take a map that contains the
// following keys: host.com and www.host.com.  If we are looking for domain
// mail.host.com, this method will return value of host.com key.  If we are
// looking for domain www.host.com, this method will return value of the
// www.host.com key.  If a more specific domain value is nil, it means that the
// domain was excluded and should be exchanged with default upstreams.
//
// Names with fewer than two dots are looked up as [UnqualifiedNames].
func (uc *UpstreamConfig) getUpstreamsForDomain(host string) (ups []upstream.Upstream) {
	if len(uc.DomainReservedUpstreams) == 0 {
		return uc.Upstreams
	}

	dotsCount := strings.Count(host, ".")
	if dotsCount < 2 {
		host = UnqualifiedNames
	} else {
		host = strings.ToLower(host)
		if uc.SubdomainExclusions.Has(host) {
			return uc.lookupSubdomainExclusion(host)
		}
	}

	// Walk the suffixes of host from the most specific one to the least
	// specific one, e.g. "www.host.com.", "host.com.", "com.", by cutting off
	// one label per iteration.  Unlike re-splitting the whole name on every
	// iteration, this doesn't allocate.
	var ok bool
	name := host
	for i := 0; i < dotsCount; i++ {
		ups, ok = uc.lookupUpstreams(name)
		if ok {
			return ups
		}

		// Every remaining suffix still contains at least one dot here, and
		// even if it didn't, IndexByte's -1 would just keep name unchanged.
		name = name[strings.IndexByte(name, '.')+1:]
	}

	return uc.Upstreams
}

// lookupSubdomainExclusion returns upstreams for host, which is expected to be
// a lowercased name from the subdomain exclusions list.  Upstreams specified for
// host itself take priority, then upstreams reserved for its parent domain, and
// finally the default upstreams.
func (uc *UpstreamConfig) lookupSubdomainExclusion(host string) (u []upstream.Upstream) {
	// A missing key yields a nil slice, so checking the length is enough.
	if ups := uc.SpecifiedDomainUpstreams[host]; len(ups) > 0 {
		return ups
	}

	// Check if there is a spec for upper level domain.  A host without any dots
	// has no parent domain; guard against it instead of indexing out of range.
	_, parent, found := strings.Cut(host, ".")
	if !found {
		return uc.Upstreams
	}

	if ups := uc.DomainReservedUpstreams[parent]; len(ups) > 0 {
		return ups
	}

	return uc.Upstreams
}

// lookupUpstreams reports whether name has an entry in the domain reserved
// upstreams and, if it does, returns the upstreams to use for it.  An entry
// without any upstreams marks an excluded domain, for which the default
// upstreams are returned instead.
func (uc *UpstreamConfig) lookupUpstreams(name string) (ups []upstream.Upstream, ok bool) {
	reserved, found := uc.DomainReservedUpstreams[name]

	switch {
	case !found:
		return reserved, false
	case len(reserved) == 0:
		// The domain has been excluded from reserved upstreams querying.
		return uc.Upstreams, true
	default:
		return reserved, true
	}
}

// Close implements the io.Closer interface for *UpstreamConfig.  It closes the
// default upstreams first, then the upstreams of the domain reserved and the
// specified domain maps, each map in lexicographic order of its domains, and
// returns all collected close errors as a single error.
func (uc *UpstreamConfig) Close() (err error) {
	errs := closeAll(nil, uc.Upstreams...)

	closeByDomain := func(byDomain map[string][]upstream.Upstream) {
		keys := make([]string, 0, len(byDomain))
		for k := range byDomain {
			keys = append(keys, k)
		}

		// Map keys are unique, so a plain sort gives a deterministic order of
		// closing and thus of the reported errors.
		slices.Sort(keys)

		for _, k := range keys {
			errs = closeAll(errs, byDomain[k]...)
		}
	}

	closeByDomain(uc.DomainReservedUpstreams)
	closeByDomain(uc.SpecifiedDomainUpstreams)

	if len(errs) == 0 {
		return nil
	}

	return errors.List("failed to close some upstreams", errs...)
}
07070100000065000081A4000000000000000000000001650C592100001AE9000000000000000000000000000000000000002800000000dnsproxy-0.55.0/proxy/upstreams_test.gopackage proxy

import (
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestGetUpstreamsForDomain(t *testing.T) {
	upstreams := []string{
		"[/google.com/local/]4.3.2.1",
		"[/www.google.com//]1.2.3.4",
		"[/maps.google.com/]#",
		"[/www.google.com/]tls://1.1.1.1",
		"[/_acme-challenge.example.org/]#",
	}

	config, err := ParseUpstreamsConfig(
		upstreams,
		&upstream.Options{
			InsecureSkipVerify: false,
			Bootstrap:          []string{},
			Timeout:            1 * time.Second,
		},
	)
	require.NoError(t, err)

	assertUpstreamsForDomain(t, config, "www.google.com.", []string{"1.2.3.4:53", "tls://1.1.1.1:853"})
	assertUpstreamsForDomain(t, config, "www2.google.com.", []string{"4.3.2.1:53"})
	assertUpstreamsForDomain(t, config, "internal.local.", []string{"4.3.2.1:53"})
	assertUpstreamsForDomain(t, config, "google.", []string{"1.2.3.4:53"})
	assertUpstreamsForDomain(t, config, "_acme-challenge.example.org.", []string{})
	assertUpstreamsForDomain(t, config, "maps.google.com.", []string{})
}

func TestUpstreamConfig_Validate(t *testing.T) {
	type testCase struct {
		wantValidateErr error
		name            string
		in              []string
	}

	testCases := []testCase{{
		wantValidateErr: upstream.ErrNoUpstreams,
		name:            "empty",
		in:              []string{},
	}, {
		wantValidateErr: upstream.ErrNoUpstreams,
		name:            "nil",
		in:              nil,
	}, {
		wantValidateErr: nil,
		name:            "valid",
		in:              []string{"udp://upstream.example:53"},
	}, {
		wantValidateErr: errNoDefaultUpstreams,
		name:            "no_default",
		in: []string{
			"[/domain.example/]udp://upstream.example:53",
			"[/another.domain.example/]#",
		},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			uc, err := ParseUpstreamsConfig(tc.in, nil)
			require.NoError(t, err)

			err = uc.validate()
			assert.ErrorIs(t, err, tc.wantValidateErr)
		})
	}

	// A nil *UpstreamConfig is reported as having no default upstreams.
	t.Run("actual_nil", func(t *testing.T) {
		var uc *UpstreamConfig
		assert.ErrorIs(t, uc.validate(), errNoDefaultUpstreams)
	})
}

// TestGetUpstreamsForDomainWithoutDuplicates checks that the same upstream
// address used for several domains is parsed into a single shared instance.
func TestGetUpstreamsForDomainWithoutDuplicates(t *testing.T) {
	upstreams := []string{"[/example.com/]1.1.1.1", "[/example.org/]1.1.1.1"}
	config, err := ParseUpstreamsConfig(
		upstreams,
		&upstream.Options{
			InsecureSkipVerify: false,
			Bootstrap:          []string{},
			Timeout:            1 * time.Second,
		},
	)
	// Use require here, since the maps are indexed below, and indexing an
	// empty slice on failure would panic instead of reporting the error.
	require.NoError(t, err)
	assert.Len(t, config.Upstreams, 0)
	require.Len(t, config.DomainReservedUpstreams, 2)

	comUps := config.DomainReservedUpstreams["example.com."]
	require.Len(t, comUps, 1)

	orgUps := config.DomainReservedUpstreams["example.org."]
	require.Len(t, orgUps, 1)

	// Check that the very same Upstream instance is used for both domains.
	assert.Same(t, comUps[0], orgUps[0])
}

func TestGetUpstreamsForDomain_wildcards(t *testing.T) {
	uconf, err := ParseUpstreamsConfig([]string{
		"0.0.0.1",
		"[/a.x/]0.0.0.2",
		"[/*.a.x/]0.0.0.3",
		"[/b.a.x/]0.0.0.4",
		"[/*.b.a.x/]0.0.0.5",
		"[/*.x.z/]0.0.0.6",
		"[/c.b.a.x/]#",
	}, nil)
	require.NoError(t, err)

	// Every case resolves to exactly one upstream, so only its address is
	// listed.  "0.0.0.1:53" is the default upstream.
	testCases := []struct {
		name string
		in   string
		want string
	}{
		{name: "default", in: "d.x.", want: "0.0.0.1:53"},
		{name: "specified_one", in: "a.x.", want: "0.0.0.2:53"},
		{name: "wildcard", in: "c.a.x.", want: "0.0.0.3:53"},
		{name: "specified_two", in: "b.a.x.", want: "0.0.0.4:53"},
		{name: "wildcard_two", in: "d.b.a.x.", want: "0.0.0.5:53"},
		{name: "specified_three", in: "c.b.a.x.", want: "0.0.0.1:53"},
		{name: "specified_four", in: "d.c.b.a.x.", want: "0.0.0.1:53"},
		{name: "unspecified_wildcard", in: "x.z.", want: "0.0.0.1:53"},
		{name: "unspecified_wildcard_sub", in: "a.x.z.", want: "0.0.0.6:53"},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assertUpstreamsForDomain(t, uconf, tc.in, []string{tc.want})
		})
	}
}

func TestGetUpstreamsForDomain_sub_wildcards(t *testing.T) {
	uconf, err := ParseUpstreamsConfig([]string{
		"0.0.0.1",
		"[/a.x/]0.0.0.2",
		"[/*.a.x/]0.0.0.3",
		"[/*.b.a.x/]0.0.0.5",
	}, nil)
	require.NoError(t, err)

	// b.a.x. has no specification of its own, so it falls under the *.a.x
	// wildcard, while its subdomains use the *.b.a.x one.
	testCases := []struct {
		name string
		in   string
		want string
	}{
		{name: "specified", in: "a.x.", want: "0.0.0.2:53"},
		{name: "wildcard", in: "c.a.x.", want: "0.0.0.3:53"},
		{name: "sub_spec_ignore", in: "b.a.x.", want: "0.0.0.3:53"},
		{name: "sub_spec_wildcard", in: "d.b.a.x.", want: "0.0.0.5:53"},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assertUpstreamsForDomain(t, uconf, tc.in, []string{tc.want})
		})
	}
}

func TestGetUpstreamsForDomain_default_wildcards(t *testing.T) {
	uconf, err := ParseUpstreamsConfig([]string{
		"127.0.0.1:5301",
		"[/example.org/]127.0.0.1:5302",
		"[/*.example.org/]127.0.0.1:5303",
		"[/www.example.org/]127.0.0.1:5304",
		"[/*.www.example.org/]#",
	}, nil)
	require.NoError(t, err)

	// "*.www.example.org" is mapped to "#", so subdomains of www.example.org
	// fall back to the default upstream.
	testCases := []struct {
		name string
		in   string
		want string
	}{
		{name: "domain", in: "example.org.", want: "127.0.0.1:5302"},
		{name: "sub_wildcard", in: "sub.example.org.", want: "127.0.0.1:5303"},
		{name: "spec_sub", in: "www.example.org.", want: "127.0.0.1:5304"},
		{name: "def_wildcard", in: "abc.www.example.org.", want: "127.0.0.1:5301"},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assertUpstreamsForDomain(t, uconf, tc.in, []string{tc.want})
		})
	}
}

// BenchmarkGetUpstreamsForDomain measures the upstream selection for a mix of
// reserved, wildcard, unqualified, and excluded domains.
func BenchmarkGetUpstreamsForDomain(b *testing.B) {
	upstreams := []string{
		"[/google.com/local/]4.3.2.1",
		"[/www.google.com//]1.2.3.4",
		"[/maps.google.com/]#",
		"[/www.google.com/]tls://1.1.1.1",
	}

	config, err := ParseUpstreamsConfig(
		upstreams,
		&upstream.Options{
			InsecureSkipVerify: false,
			Bootstrap:          []string{},
			Timeout:            1 * time.Second,
		},
	)
	// Don't benchmark a broken configuration.
	require.NoError(b, err)

	// Exclude the configuration parsing above from the measurements.
	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		assertUpstreamsForDomain(b, config, "www.google.com.", []string{"1.2.3.4:53", "tls://1.1.1.1:853"})
		assertUpstreamsForDomain(b, config, "www2.google.com.", []string{"4.3.2.1:53"})
		assertUpstreamsForDomain(b, config, "internal.local.", []string{"4.3.2.1:53"})
		assertUpstreamsForDomain(b, config, "google.", []string{"1.2.3.4:53"})
		assertUpstreamsForDomain(b, config, "maps.google.com.", []string{})
	}
}

// assertUpstreamsForDomain checks that the upstreams selected for domain have
// exactly the addresses listed in address, in the same order.
func assertUpstreamsForDomain(t testing.TB, config *UpstreamConfig, domain string, address []string) {
	t.Helper()

	ups := config.getUpstreamsForDomain(domain)
	require.Len(t, ups, len(address))

	for idx, want := range address {
		assert.Equalf(t, want, ups[idx].Address(), "bad upstream at index %d", idx)
	}
}
07070100000066000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001A00000000dnsproxy-0.55.0/proxyutil07070100000067000081A4000000000000000000000001650C59210000025B000000000000000000000000000000000000002100000000dnsproxy-0.55.0/proxyutil/dns.go// Package proxyutil contains helper functions that are used in all other
// dnsproxy packages.
package proxyutil

import (
	"encoding/binary"
	"net"

	"github.com/miekg/dns"
)

// AddPrefix returns a new slice containing b preceded by its length encoded as
// a 2-byte big-endian integer, which is the length framing used for DNS
// messages sent over TCP (RFC 1035, section 4.2.2).  b itself is not modified.
//
// The length is converted to uint16, so for len(b) greater than 65535 the
// prefix silently wraps around.  Only pass slices no longer than that.
func AddPrefix(b []byte) (m []byte) {
	l := len(b)

	m = make([]byte, 2, 2+l)
	binary.BigEndian.PutUint16(m, uint16(l))

	return append(m, b...)
}

// IPFromRR returns the IP address carried by rr if it is an A or an AAAA
// record, and nil for any other record type.  Addresses from A records are
// returned in their 4-byte form, or as nil if they can't be represented so.
func IPFromRR(rr dns.RR) (ip net.IP) {
	if a, ok := rr.(*dns.A); ok {
		return a.A.To4()
	}

	if aaaa, ok := rr.(*dns.AAAA); ok {
		return aaaa.AAAA
	}

	return nil
}
07070100000068000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001800000000dnsproxy-0.55.0/scripts07070100000069000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001E00000000dnsproxy-0.55.0/scripts/hooks0707010000006A000081ED000000000000000000000001650C592100000867000000000000000000000000000000000000002900000000dnsproxy-0.55.0/scripts/hooks/pre-commit#!/bin/sh

# Exit on errors (-e), disable filename expansion (-f), and treat unset
# variables as errors (-u).
set -e -f -u

# This comment is used to simplify checking local copies of the script.
# Bump this number every time a significant change is made to this
# script.
#
# AdGuard-Project-Version: 1

# TODO(a.garipov): Add pre-merge-commit.

# Only show interactive prompts if a terminal is attached to stdout.  While
# this technically doesn't guarantee that reading from /dev/tty works, this
# should work reasonably well on all of our supported development systems and
# in most terminal emulators.
#
# is_tty is '1' when stdout is a terminal and '0' otherwise.
is_tty='0'
if [ -t '1' ]
then
	is_tty='1'
fi
readonly is_tty

# prompt is a helper that prompts the user for interactive input if that
# can be done.  If there is no terminal attached, it sleeps for two
# seconds, giving the programmer some time to react, and returns with
# a zero exit code.
#
# Otherwise, it asks whether to commit anyway.  An answer of "y" or "Y"
# returns, while an empty answer, "n", or "N" exits the hook with status 1,
# which makes git abort the commit.  Any other answer repeats the question.
prompt() {
	if [ "$is_tty" -eq '0' ]
	then
		sleep 2

		return 0
	fi

	while true
	do
		printf 'commit anyway? y/[n]: '
		# Read the answer from the terminal device rather than from stdin.
		read -r ans < /dev/tty

		case "$ans"
		in
		('y'|'Y')
			break
			;;
		(''|'n'|'N')
			exit 1
			;;
		(*)
			continue
			;;
		esac
	done
}

# Warn the programmer about unstaged changes and untracked files, but do
# not fail the commit, because those changes might be temporary or for
# a different branch.
#
# The fields follow the "git status --porcelain=2" format.  Ordinary changed
# entries ("1") have the path in field 9, renamed or copied entries ("2") in
# field 10, unmerged entries ("u") in field 11, and untracked entries ("?") in
# field 2.  The second character of the XY field (field 2) is "." when the
# work tree matches the index.  Selecting the field by entry type prevents
# printing empty lines for untracked files, similarity scores for renames, or
# object hashes for unmerged files.
awk_prog='
$1 == "1" && substr($2, 2, 1) != "." { print $9; }
$1 == "2" && substr($2, 2, 1) != "." { print $10; }
$1 == "u" { print $11; }
$1 == "?" { print $2; }
'
readonly awk_prog

unstaged="$( git status --porcelain=2 | awk "$awk_prog" )"
readonly unstaged

if [ "$unstaged" != "" ]
then
	printf 'WARNING: you have unstaged changes:\n\n%s\n\n' "$unstaged"
	prompt
fi

# Warn the programmer about temporary todos, but do not fail the commit,
# because the commit could be in a temporary branch.
temp_todos="$( git grep -e 'TODO.*!!' -- ':!scripts/hooks/pre-commit' || : )"
readonly temp_todos

if [ "$temp_todos" != "" ]
then
	printf 'WARNING: you have temporary todos:\n\n%s\n\n' "$temp_todos"
	prompt
fi

verbose="${VERBOSE:-0}"
readonly verbose

# Run the Go checks only if Go-related files are staged, and the text linters
# only if documentation or YAML files are staged.
if [ "$( git diff --cached --name-only -- '*.go' '*.mod' '*.sh' 'Makefile' )" ]
then
	make VERBOSE="$verbose" go-os-check go-lint go-test
fi

if [ "$( git diff --cached --name-only -- '*.md' '*.yaml' '*.yml' )" ]
then
	make VERBOSE="$verbose" txt-lint
fi
0707010000006B000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001D00000000dnsproxy-0.55.0/scripts/make0707010000006C000081A4000000000000000000000001650C592100000BC4000000000000000000000000000000000000002D00000000dnsproxy-0.55.0/scripts/make/build-docker.sh#!/bin/sh

# At verbosity levels above 0, trace every command and pass --debug=1 to
# docker.
verbose="${VERBOSE:-0}"

if [ "$verbose" -gt '0' ]
then
	set -x
	debug_flags='--debug=1'
else
	set +x
	debug_flags='--debug=0'
fi
readonly debug_flags

# Exit on errors (-e), disable filename expansion (-f), and treat unset
# variables as errors (-u).
set -e -f -u

# Require these to be set.  The :? expansions abort the script with the given
# message if a variable is unset or empty.
commit="${REVISION:?please set REVISION}"
dist_dir="${DIST_DIR:?please set DIST_DIR}"
version="${VERSION:?please set VERSION}"
readonly commit dist_dir version

# Allow users to use sudo, e.g. SUDO='sudo'.  The value is placed unquoted in
# front of the docker command below, so an empty value adds nothing.
sudo_cmd="${SUDO:-}"
readonly sudo_cmd

# Target platforms of the multi-platform image.  Each of them needs a matching
# binary copied into the Docker context below.
docker_platforms='linux/386,linux/amd64,linux/arm/v6,linux/arm/v7,linux/arm64,linux/ppc64le'
readonly docker_platforms

build_date="$( date -u +'%Y-%m-%dT%H:%M:%SZ' )"
readonly build_date

# Set DOCKER_IMAGE_NAME to 'adguard/dnsproxy' if you want (and are allowed)
# to push to DockerHub.
docker_image_name="${DOCKER_IMAGE_NAME:-dnsproxy-dev}"
readonly docker_image_name

# Set DOCKER_OUTPUT to 'type=image,name=adguard/dnsproxy,push=true' if you
# want (and are allowed) to push to DockerHub.
#
# To inspect the resulting image using commands like "docker image ls", change
# type to docker and also set docker_platforms to a single platform.
#
# See https://github.com/docker/buildx/issues/166.
docker_output="${DOCKER_OUTPUT:-type=image,name=${docker_image_name},push=false}"
readonly docker_output

docker_version_tag="--tag=${docker_image_name}:${version}"

# Only set the "latest" tag for real versions.  For 'dev' (or an empty
# version), use just the version tag to avoid polluting "latest".
case "${version:-}"
in
('dev'|'')
	docker_channel_tag=''
	;;
(*)
	docker_channel_tag="--tag=${docker_image_name}:latest"
	;;
esac

readonly docker_version_tag docker_channel_tag

# Copy the binaries into a new directory under new names, so that it's easier
# to COPY them later.  DO NOT remove the trailing underscores.  See file
# docker/Dockerfile.
dist_docker="${dist_dir}/docker"
readonly dist_docker

mkdir -p "$dist_docker"

# Architectures without a variant: linux-ARCH becomes dnsproxy_linux_ARCH_.
for docker_arch in 386 amd64 arm64
do
	cp "${dist_dir}/linux-${docker_arch}/dnsproxy"\
		"${dist_docker}/dnsproxy_linux_${docker_arch}_"
done

# ARM variants: linux-armN becomes dnsproxy_linux_arm_vN.
for docker_arm in 6 7
do
	cp "${dist_dir}/linux-arm${docker_arm}/dnsproxy"\
		"${dist_docker}/dnsproxy_linux_arm_v${docker_arm}"
done

cp "${dist_dir}/linux-ppc64le/dnsproxy"\
	"${dist_docker}/dnsproxy_linux_ppc64le_"

# Prepare the default configuration for the Docker image.
cp ./config.yaml.dist "${dist_docker}/config.yaml"

# Build the multi-platform image.  Whether it is pushed is decided by the
# push= setting in $docker_output.
#
# Don't use quotes with $sudo_cmd, $docker_version_tag, and
# $docker_channel_tag, because we want word splitting, and an empty value must
# expand to no argument at all rather than to an empty one.
#
# TODO(a.garipov): Once flag --tag of docker buildx build supports commas, use
# them instead.
$sudo_cmd docker\
	"$debug_flags"\
	buildx build\
	--build-arg BUILD_DATE="$build_date"\
	--build-arg DIST_DIR="$dist_dir"\
	--build-arg VCS_REF="$commit"\
	--build-arg VERSION="$version"\
	--output "$docker_output"\
	--platform "$docker_platforms"\
	$docker_version_tag\
	$docker_channel_tag\
	-f ./docker/Dockerfile\
	.
0707010000006D000081A4000000000000000000000001650C592100000CEE000000000000000000000000000000000000002E00000000dnsproxy-0.55.0/scripts/make/build-release.sh#!/bin/sh

verbose="${VERBOSE:-0}"
readonly verbose

# Print the environment at verbosity levels above 2, and trace every command
# at levels above 1.
if [ "$verbose" -gt '2' ]
then
	env
	set -x
elif [ "$verbose" -gt '1' ]
then
	set -x
fi

# Exit on errors (-e), disable filename expansion (-f), and treat unset
# variables as errors (-u).
set -e -f -u

# log prints its first argument to stderr if the verbosity level is greater
# than 0.  It uses printf instead of echo, because the handling of backslashes
# by echo differs between shells.
log() {
	if [ "$verbose" -gt '0' ]
	then
		printf '%s\n' "$1" 1>&2
	fi
}

log 'starting to build dnsproxy release'

version="${VERSION:-}"
readonly version

log "version '$version'"

dist="${DIST_DIR:-build}"
readonly dist

out="${OUT:-dnsproxy}"
readonly out

log "checking tools"

# Check the external archiving tools used by the build function below.  Report
# a missing tool unconditionally, so that the script doesn't exit with an error
# silently at the default verbosity level.
for tool in gzip tar zip
do
	if ! command -v "$tool" > /dev/null
	then
		printf "tool '%s' not found\n" "$tool" 1>&2

		exit 1
	fi
done

# Data section.  Arrange data into space-separated tables for read -r to read.
# Use 0 for missing values.

#    os  arch      arm mips
platforms="\
darwin   amd64     0   0
darwin   arm64     0   0
freebsd  386       0   0
freebsd  amd64     0   0
freebsd  arm       5   0
freebsd  arm       6   0
freebsd  arm       7   0
freebsd  arm64     0   0
linux    386       0   0
linux    amd64     0   0
linux    arm       5   0
linux    arm       6   0
linux    arm       7   0
linux    arm64     0   0
linux    mips      0   softfloat
linux    mips64    0   softfloat
linux    mips64le  0   softfloat
linux    mipsle    0   softfloat
linux    ppc64le   0   0
openbsd  amd64     0   0
openbsd  arm64     0   0
windows  386       0   0
windows  amd64     0   0
windows  arm64     0   0"
readonly platforms

# build builds dnsproxy for a single platform and packs the result into an
# archive inside $dist.  Arguments:
#
#   $1: name of the build, e.g. "linux-arm7", used for the directory and the
#       archive names;
#   $2: GOOS;
#   $3: GOARCH;
#   $4: GOARM or 0;
#   $5: GOMIPS or 0.
build() {
	# Get the arguments.  Here and below, use the "build_" prefix for all
	# variables local to function build.
	build_dir="${dist}/${1}"\
		build_name="$1"\
		build_os="$2"\
		build_arch="$3"\
		build_arm="$4"\
		build_mips="$5"\
		;

	# Use the ".exe" filename extension if we build a Windows release.
	if [ "$build_os" = 'windows' ]
	then
		build_output="./${build_dir}/${out}.exe"
	else
		build_output="./${build_dir}/${out}"
	fi

	mkdir -p "./${build_dir}"

	# Build the binary.
	#
	# Set GOARM and GOMIPS to an empty string if $build_arm and $build_mips
	# are zero by removing the zero as if it's a prefix.
	#
	# Use $build_os rather than a variable from the caller's scope, so that
	# the function only depends on its own arguments.
	env\
		GOARCH="$build_arch"\
		GOARM="${build_arm#0}"\
		GOMIPS="${build_mips#0}"\
		GOOS="$build_os"\
		VERBOSE="$(( verbose - 1 ))"\
		VERSION="$version"\
		OUT="$build_output"\
		sh ./scripts/make/go-build.sh\
		;

	log "$build_output"

	# Prepare the build directory for archiving.
	cp ./LICENSE ./README.md "$build_dir"

	# Make archives.  Windows prefers ZIP archives; the rest, gzipped tarballs.
	build_archive_base="${out}-${build_name}-${version}"
	case "$build_os"
	in
	('windows')
		build_archive="./${dist}/${build_archive_base}.zip"
		# TODO(a.garipov): Find an option similar to the -C option of tar for
		# zip.
		#
		# Until then, change into $dist in a subshell and give zip an archive
		# path relative to $dist, so that nested values of DIST_DIR work too.
		( cd "${dist}" && zip -9 -q -r "./${build_archive_base}.zip" "./${build_name}" )
		;;
	(*)
		build_archive="./${dist}/${build_archive_base}.tar.gz"
		tar -C "./${dist}" -c -f - "./${build_name}" | gzip -9 - > "$build_archive"
		;;
	esac

	log "$build_archive"
}

log "starting builds"

# Feed every row of the platforms table above to build.  ARM builds get the ARM
# version appended to their name, e.g. "linux-arm7"; all others are named just
# "OS-ARCH".
echo "$platforms" | while read -r os arch arm mips
do
	name="${os}-${arch}"
	if [ "$arch" = 'arm' ]
	then
		name="${name}${arm}"
	fi

	build "$name" "$os" "$arch" "$arm" "$mips"
done

log "finished"
0707010000006E000081A4000000000000000000000001650C592100000B62000000000000000000000000000000000000002900000000dnsproxy-0.55.0/scripts/make/go-build.sh#!/bin/sh

# dnsproxy build script
#
# The commentary in this file is written with the assumption that the reader
# only has superficial knowledge of the POSIX shell language and alike.
# Experienced readers may find it overly verbose.

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1

# The default verbosity level is 0.  Show every command that is run and every
# package that is processed if the caller requested verbosity level greater than
# 0.  Also show subcommands and print the environment if the requested
# verbosity level is greater than 1.  Otherwise, do nothing.
verbose="${VERBOSE:-0}"
readonly verbose

if [ "$verbose" -gt '1' ]
then
	env
	set -x
	v_flags='-v=1'
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=0'
else
	set +x
	v_flags='-v=0'
	x_flags='-x=0'
fi
readonly x_flags v_flags

# Exit the script if a pipeline fails (-e), prevent accidental filename
# expansion (-f), and consider undefined variables as errors (-u).
set -e -f -u

# Allow users to override the go command from environment.  For example, to
# build two releases with two different Go versions and test the difference.
go="${GO:-go}"
readonly go

# Set the build parameters unless already set.
branch="${BRANCH:-$( git rev-parse --abbrev-ref HEAD )}"
revision="${REVISION:-$( git rev-parse --short HEAD )}"
version="${VERSION:-0}"
readonly branch revision version

# Set date and time of the latest commit unless already set.
committime="${SOURCE_DATE_EPOCH:-$( git log -1 --pretty=%ct )}"
readonly committime

# Compile them in.  The -s and -w linker flags omit the symbol table and the
# DWARF debugging information, and each -X sets a string variable in the
# version package.
version_pkg='github.com/AdguardTeam/dnsproxy/internal/version'
ldflags="-s -w"
ldflags="${ldflags} -X ${version_pkg}.branch=${branch}"
ldflags="${ldflags} -X ${version_pkg}.committime=${committime}"
ldflags="${ldflags} -X ${version_pkg}.revision=${revision}"
ldflags="${ldflags} -X ${version_pkg}.version=${version}"
readonly ldflags version_pkg

# Allow users to limit the build's parallelism.
parallelism="${PARALLELISM:-}"
readonly parallelism

# Use GOFLAGS for -p, because -p=0 simply disables the build instead of leaving
# the default value.  GOFLAGS is exported so that the go command sees the
# added flag.
if [ "${parallelism}" != '' ]
then
        GOFLAGS="${GOFLAGS:-} -p=${parallelism}"
fi
readonly GOFLAGS
export GOFLAGS

# Allow users to specify a different output name.
out="${OUT:-dnsproxy}"
readonly out

o_flags="-o=${out}"
readonly o_flags

# Allow users to enable the race detector.  Unfortunately, that means that cgo
# must be enabled.
if [ "${RACE:-0}" -eq '0' ]
then
	CGO_ENABLED='0'
	race_flags='--race=0'
else
	CGO_ENABLED='1'
	race_flags='--race=1'
fi
readonly CGO_ENABLED race_flags
export CGO_ENABLED

# Always build in module mode.
GO111MODULE='on'
export GO111MODULE

# Print the Go environment for debugging.
if [ "$verbose" -gt '0' ]
then
	"$go" env
fi

# Build the package in the current directory.
"$go" build\
	--ldflags="$ldflags"\
	"$race_flags"\
	--trimpath\
	"$o_flags"\
	"$v_flags"\
	"$x_flags"
0707010000006F000081A4000000000000000000000001650C5921000001D8000000000000000000000000000000000000002800000000dnsproxy-0.55.0/scripts/make/go-deps.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1

# Verbosity levels:
#   0 = Don't print anything except for errors.
#   1 = Print commands, but not nested commands.
#   2 = Print everything, including the environment and the commands executed
#       by the go tool.
verbose="${VERBOSE:-0}"
readonly verbose

if [ "$verbose" -gt '1' ]
then
	env
	set -x
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	x_flags='-x=0'
else
	set +x
	x_flags='-x=0'
fi
readonly x_flags

# Exit on errors (-e), disable filename expansion (-f), and treat unset
# variables as errors (-u).
set -e -f -u

go="${GO:-go}"
readonly go

# Download the module dependencies into the module cache.
"$go" mod download "$x_flags"
07070100000070000081A4000000000000000000000001650C592100001493000000000000000000000000000000000000002800000000dnsproxy-0.55.0/scripts/make/go-lint.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 5

# Trace every command if the caller requested a verbosity level greater than 0.
verbose="${VERBOSE:-0}"
readonly verbose

if [ "$verbose" -gt '0' ]
then
	set -x
fi

# Set $EXIT_ON_ERROR to zero to see all errors.  Otherwise, the script exits
# as soon as a command fails (set -e).
if [ "${EXIT_ON_ERROR:-1}" -eq '0' ]
then
	set +e
else
	set -e
fi

# Disable filename expansion (-f) and treat unset variables as errors (-u).
set -f -u



# Source the common helpers, including not_found and run_linter.
. ./scripts/make/helper.sh



# Warnings

# Warn, but don't fail, if the go command in use doesn't report the recommended
# version.  The check looks for $go_min_version as a substring of the output of
# "go version".
go_version="$( "${GO:-go}" version )"
readonly go_version

go_min_version='go1.20.8'
go_version_msg="
warning: your go version (${go_version}) is different from the recommended minimal one (${go_min_version}).
if you have the version installed, please set the GO environment variable.
for example:

	export GO='${go_min_version}'
"
readonly go_min_version go_version_msg

case "$go_version"
in
('go version'*"$go_min_version"*)
	# Go on.
	;;
(*)
	echo "$go_version_msg" 1>&2
	;;
esac



# Simple analyzers

# blocklist_imports is a simple check against unwanted packages.  The following
# packages are banned:
#
#   *  Packages errors and log are replaced by our own packages in the
#      github.com/AdguardTeam/golibs module.
#
#   *  Package io/ioutil is soft-deprecated.
#
#   *  Package reflect is often overkill, and for deep comparisons there are
#      much better functions in module github.com/google/go-cmp.  Which is
#      already our indirect dependency and which may or may not enter the stdlib
#      at some point.
#
#      See https://github.com/golang/go/issues/45200.
#
#   *  Package sort is replaced by golang.org/x/exp/slices.
#
#   *  Package unsafe is… unsafe.
#
#   *  Package golang.org/x/net/context has been moved into stdlib.
#
# Currently, the only standard exceptions are files generated from protobuf
# schemas, which use package reflect.  If your project needs more exceptions,
# add and document them.
#
# Each match is printed as "FILE:LINE: blocked import: ...".  The sed
# expression uses \+, which is a GNU extension.
#
# NOTE(review): the "|| exit 0" presumably relies on run_linter calling this
# function in a subshell; otherwise it would exit the whole script.  Confirm
# against helper.sh.
#
# TODO(a.garipov): Add deprecated packages golang.org/x/exp/maps and
# golang.org/x/exp/slices once all projects switch to Go 1.21.
blocklist_imports() {
	git grep\
		-e '[[:space:]]"errors"$'\
		-e '[[:space:]]"io/ioutil"$'\
		-e '[[:space:]]"log"$'\
		-e '[[:space:]]"reflect"$'\
		-e '[[:space:]]"sort"$'\
		-e '[[:space:]]"unsafe"$'\
		-e '[[:space:]]"golang.org/x/net/context"$'\
		-n\
		-- '*.go'\
		':!*.pb.go'\
		| sed -e 's/^\([^[:space:]]\+\)\(.*\)$/\1 blocked import:\2/'\
		|| exit 0
}

# method_const is a simple check against the usage of raw HTTP method strings
# where one should use named constants.
method_const() {
	git grep -F\
		-e '"DELETE"'\
		-e '"GET"'\
		-e '"PATCH"'\
		-e '"POST"'\
		-e '"PUT"'\
		-n\
		-- '*.go'\
		| sed -e 's/^\([^[:space:]]\+\)\(.*\)$/\1 http method literal:\2/'\
		|| exit 0
}

# underscores is a simple check against Go filenames with underscores.  Add new
# build tags and OS as you go.  The main goal of this check is to discourage the
# use of filenames like client_manager.go.
#
# Files with any of the allowed suffixes below are filtered out by grep -v, and
# the rest are printed indented with a tab (the sed \t and \0 are GNU
# extensions).
#
# NOTE: readonly inside a function marks a global variable, so this function
# can only be called once per script run.
underscores() {
	underscore_files="$(
		git ls-files '*_*.go'\
			| grep -F\
			-e '_darwin.go'\
			-e '_generate.go'\
			-e '_linux.go'\
			-e '_others.go'\
			-e '_plan9.go'\
			-e '_test.go'\
			-e '_unix.go'\
			-e '_windows.go'\
			-e '_dnscrypt.go'\
			-e '_https.go'\
			-e '_quic.go'\
			-e '_tcp.go'\
			-e '_udp.go'\
			-v\
			| sed -e 's/./\t\0/'
	)"
	readonly underscore_files

	if [ "$underscore_files" != '' ]
	then
		echo 'found file names with underscores:'
		echo "$underscore_files"
	fi
}

# TODO(a.garipov): Add an analyzer to look for `fallthrough`, `goto`, and `new`?



# Checks

run_linter -e blocklist_imports

run_linter -e method_const

run_linter -e underscores

run_linter -e gofumpt --extra -e -l .

# TODO(a.garipov): golint is deprecated, find a suitable replacement.

# Use the same fallback as for the version check above, since a bare "$GO"
# aborts the script under set -u when GO isn't set.
run_linter "${GO:-go}" vet ./...

run_linter govulncheck ./...

# TODO(a.garipov): Enable for all.
run_linter gocyclo --over 10\
	./internal/bootstrap/\
	./internal/netutil/\
	./internal/version/\
	./proxyutil/\
	;

run_linter gocyclo --over 20 ./main.go
run_linter gocyclo --over 18 ./fastip/
run_linter gocyclo --over 15 ./proxy/
run_linter gocyclo --over 14 ./upstream/

# TODO(a.garipov): Enable for all.
run_linter gocognit --over 10\
	./internal/bootstrap/\
	./internal/version/\
	./proxyutil/\
	;

run_linter gocognit --over 39 ./main.go
run_linter gocognit --over 33 ./proxy/
run_linter gocognit --over 32 ./fastip/
run_linter gocognit --over 24 ./upstream/
run_linter gocognit --over 14 ./internal/netutil/

run_linter ineffassign ./...

run_linter unparam ./...

# NOTE(review): without pipefail, the exit status of this pipeline is that of
# sed, so misspellings are reported but don't fail the script.  Confirm that
# this is intended.
git ls-files -- 'Makefile' '*.conf' '*.go' '*.mod' '*.sh' '*.yaml' '*.yml'\
	| xargs misspell --error\
	| sed -e 's/^/misspell: /'

run_linter looppointer ./...

run_linter nilness ./...

# TODO(a.garipov): Enable for all.
run_linter fieldalignment ./fastip/...
run_linter fieldalignment ./internal/...
run_linter fieldalignment ./proxyutil/...
run_linter fieldalignment ./upstream/...

run_linter -e shadow --strict ./...

run_linter gosec --quiet ./...

run_linter errcheck ./...

staticcheck_matrix='
darwin:  GOOS=darwin
freebsd: GOOS=freebsd
linux:   GOOS=linux
openbsd: GOOS=openbsd
windows: GOOS=windows
'
readonly staticcheck_matrix

echo "$staticcheck_matrix" | run_linter staticcheck --matrix ./...
07070100000071000081A4000000000000000000000001650C5921000003E8000000000000000000000000000000000000002800000000dnsproxy-0.55.0/scripts/make/go-test.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 1

verbose="${VERBOSE:-0}"
readonly verbose

# Verbosity levels:
#   0 = Don't print anything except for errors.
#   1 = Print commands, but not nested commands.
#   2 = Print everything.
if [ "$verbose" -gt '1' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]
then
	set -x
	v_flags='-v=1'
	x_flags='-x=0'
else
	set +x
	v_flags='-v=0'
	x_flags='-x=0'
fi
readonly v_flags x_flags

# Exit on errors (-e), disable filename expansion (-f), and treat unset
# variables as errors (-u).
set -e -f -u

# Run the tests with the race detector unless RACE is set to 0.
if [ "${RACE:-1}" -eq '0' ]
then
	race_flags='--race=0'
else
	race_flags='--race=1'
fi
readonly race_flags

# --count=1 disables the test result cache, --shuffle=on randomizes the order
# of tests, and TIMEOUT_FLAGS may override the default timeout flag.
go="${GO:-go}"
count_flags='--count=1'
shuffle_flags='--shuffle=on'
timeout_flags="${TIMEOUT_FLAGS:---timeout=2m}"
readonly go count_flags shuffle_flags timeout_flags

"$go" test\
	"$count_flags"\
	"$race_flags"\
	"$shuffle_flags"\
	"$timeout_flags"\
	"$v_flags"\
	"$x_flags"\
	./...
07070100000072000081A4000000000000000000000001650C592100000766000000000000000000000000000000000000002900000000dnsproxy-0.55.0/scripts/make/go-tools.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 3

verbose="${VERBOSE:-0}"
readonly verbose

# Verbosity levels:
#   0 = Don't print anything except for errors.
#   1 = Trace this script and make go install verbose.
#   2 = Also make go install print the commands it runs.
if [ "$verbose" -gt '1' ]; then
	set -x
	v_flags='-v=1'
	x_flags='-x=1'
elif [ "$verbose" -gt '0' ]; then
	set -x
	v_flags='-v=1'
	x_flags='-x=0'
else
	set +x
	v_flags='-v=0'
	x_flags='-x=0'
fi
readonly v_flags x_flags

set -e -f -u

go="${GO:-go}"
readonly go

# Remove only the actual binaries in the bin/ directory, as developers may add
# their own scripts there.  Most commonly, a script named “go” for tools that
# call the go binary and need a particular version.
rm -f \
	bin/errcheck \
	bin/fieldalignment \
	bin/gocognit \
	bin/gocyclo \
	bin/gofumpt \
	bin/gosec \
	bin/govulncheck \
	bin/ineffassign \
	bin/looppointer \
	bin/misspell \
	bin/nilness \
	bin/shadow \
	bin/staticcheck \
	bin/unparam

# Reset GOARCH and GOOS to make sure we install the tools for the native
# architecture even when we're cross-compiling the main binary, and also to
# prevent the "cannot install cross-compiled binaries when GOBIN is set" error.
env \
	GOARCH='' \
	GOBIN="$PWD/bin" \
	GOOS='' \
	GOWORK='off' \
	"$go" install \
	--modfile=./internal/tools/go.mod \
	"$v_flags" \
	"$x_flags" \
	github.com/fzipp/gocyclo/cmd/gocyclo \
	github.com/golangci/misspell/cmd/misspell \
	github.com/gordonklaus/ineffassign \
	github.com/kisielk/errcheck \
	github.com/kyoh86/looppointer/cmd/looppointer \
	github.com/securego/gosec/v2/cmd/gosec \
	github.com/uudashr/gocognit/cmd/gocognit \
	golang.org/x/tools/go/analysis/passes/fieldalignment/cmd/fieldalignment \
	golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness \
	golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow \
	golang.org/x/vuln/cmd/govulncheck \
	honnef.co/go/tools/cmd/staticcheck \
	mvdan.cc/gofumpt \
	mvdan.cc/unparam
07070100000073000081A4000000000000000000000001650C59210000064C000000000000000000000000000000000000002700000000dnsproxy-0.55.0/scripts/make/helper.sh#!/bin/sh

# Common script helpers
#
# This file contains common script helpers.  It should be sourced in scripts
# right after the initial environment processing.

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 2



# Deferred helpers

# not_found_msg is the hint that not_found prints to stderr.
not_found_msg='
looks like a binary not found error.
make sure you have installed the linter binaries using:

	$ make go-tools
'
readonly not_found_msg

# not_found is the EXIT trap handler that prints not_found_msg when the script
# exits with code 127.  Code 127 is the exit status a shell uses when a command
# or a file is not found, according to the Bash Hackers wiki.
#
# See https://wiki.bash-hackers.org/dict/terms/exit_status.
not_found() {
	case "$?" in
	('127')
		echo "$not_found_msg" 1>&2
		;;
	(*)
		# Go on.
		;;
	esac
}
trap not_found EXIT



# Helpers

# run_linter runs the given linter with two additions:
#
# 1.  If the first argument is "-e", run_linter exits with a nonzero exit code
#     if there is anything in the command's standard output.
#
# 2.  In any case, run_linter prefixes every line of the command's standard
#     output with the program's name.  Standard error is not captured and is
#     passed through as is.
#
# run_linter returns the command's exit code, or 1 if "-e" was given and the
# command succeeded but printed something.  It runs in a subshell, so the shell
# options it changes don't leak into the caller.
#
# VERBOSE is read with a default, because the scripts that source this file use
# "set -u" and VERBOSE may be unset when such a script is run directly.
run_linter() (
	set +e

	if [ "${VERBOSE:-0}" -lt '2' ]
	then
		set +x
	fi

	cmd="${1:?run_linter: provide a command}"
	shift

	exit_on_output='0'
	if [ "$cmd" = '-e' ]
	then
		exit_on_output='1'
		cmd="${1:?run_linter: provide a command}"
		shift
	fi

	readonly cmd

	output="$( "$cmd" "$@" )"
	exitcode="$?"

	readonly output

	if [ "$output" != '' ]
	then
		# Use printf instead of echo, since the echo of some shells, for
		# example dash, interprets backslash escapes, which could mangle the
		# linter's output.
		printf '%s\n' "$output" | sed -e "s/^/${cmd}: /"

		if [ "$exitcode" -eq '0' ] && [ "$exit_on_output" -eq '1' ]
		then
			exitcode='1'
		fi
	fi

	return "$exitcode"
)
07070100000074000081A4000000000000000000000001650C592100000687000000000000000000000000000000000000002900000000dnsproxy-0.55.0/scripts/make/txt-lint.sh#!/bin/sh

# This comment is used to simplify checking local copies of the script.  Bump
# this number every time a significant change is made to this script.
#
# AdGuard-Project-Version: 4

verbose="${VERBOSE:-0}"
readonly verbose

# Trace this script at any nonzero verbosity level.
if [ "$verbose" -gt '0' ]; then
	set -x
fi

# Stop at the first error by default.  Set $EXIT_ON_ERROR to zero to see all
# errors.
if [ "${EXIT_ON_ERROR:-1}" -eq '0' ]; then
	set +e
else
	set -e
fi

# Glob expansions aren't needed here, and errors about unset variables are
# wanted.
set -f -u

# Source the common helpers, including not_found and run_linter.
. ./scripts/make/helper.sh

# Simple analyzers

# trailing_newlines is a simple check that makes sure that all plain-text files
# have a trailing newline to make sure that all tools work correctly with them.
# It prints the name of every file that doesn't.
#
# Command substitution strips all trailing newlines, so capturing the last byte
# of a file yields an empty string exactly when that byte is a newline.  Empty
# files also yield an empty string and so pass the check.
trailing_newlines() {
	# NOTE: Adjust for your project.
	# TODO(d.kolyshev): Remove "vendor" after moving out from vendor dir usage.
	git ls-files\
		':!vendor/*'\
		| while IFS= read -r f
		do
			if [ "$( tail -c -1 "$f" )" != '' ]
			then
				printf '%s: must have a trailing newline\n' "$f"
			fi
		done
}

# trailing_whitespace is a simple check that makes sure that there is no
# trailing whitespace in plain-text files.  It prints every offending line as
# "file:line:content" with the trailing whitespace wrapped in ">>>" and "<<<".
trailing_whitespace() {
	# NOTE: Adjust for your project.
	# TODO(d.kolyshev): Remove "vendor" after moving out from vendor dir usage.
	git ls-files\
		':!vendor/*'\
		| while IFS= read -r f
		do
			# Pass /dev/null as a second file so that grep itself prefixes the
			# matches with the file name.  This keeps the name out of the sed
			# expression, where characters like ':' or '&' would break it.
			# The highlighting uses only POSIX BRE, since '\+' is a GNU
			# extension that BSD sed doesn't support.
			grep -e '[[:space:]]$' -n -- "$f" /dev/null\
				| sed -e 's/[[:space:]][[:space:]]*$/>>>&<<</'
		done
}

run_linter -e trailing_newlines

run_linter -e trailing_whitespace

# Check the plain-text files for misspellings.  The output is captured before
# it is prefixed.  In a POSIX shell, the exit status of a pipeline is that of
# its last command, so piping misspell straight into sed would discard the
# nonzero status that --error asks for.  An exit code of 127 from xargs, which
# means that misspell is not installed, is also reported by not_found.
misspell_exitcode='0'
misspell_output="$(
	git ls-files -- '*.conf' '*.md' '*.txt' '*.yaml' '*.yml'\
		| xargs misspell --error
)" || misspell_exitcode="$?"
readonly misspell_exitcode misspell_output

if [ "$misspell_output" != '' ]
then
	printf '%s\n' "$misspell_output" | sed -e 's/^/misspell: /'
fi

exit "$misspell_exitcode"
07070100000075000041ED000000000000000000000002650C592100000000000000000000000000000000000000000000001900000000dnsproxy-0.55.0/upstream07070100000076000081A4000000000000000000000001650C59210000104F000000000000000000000000000000000000002500000000dnsproxy-0.55.0/upstream/dnscrypt.gopackage upstream

import (
	"fmt"
	"io"
	"net/url"
	"os"
	"sync"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/miekg/dns"
)

// dnsCrypt implements the [Upstream] interface for the DNSCrypt protocol.  The
// client and the resolver info are fetched lazily and renewed when missing,
// expired, or after a failed exchange.
type dnsCrypt struct {
	// mu protects client and resolverInfo.
	mu *sync.RWMutex

	// client stores the DNSCrypt client properties.  It is nil until the
	// first successful certificate fetch and after a failed one.
	client *dnscrypt.Client

	// resolverInfo stores the DNSCrypt server properties.  It is nil whenever
	// client is nil.
	resolverInfo *dnscrypt.ResolverInfo

	// addr is the DNSCrypt server URL.
	addr *url.URL

	// verifyCert is a callback that verifies the resolver's certificate.  If
	// it is nil, the certificate isn't verified.
	verifyCert func(cert *dnscrypt.Cert) (err error)

	// timeout is the timeout for the DNS requests.
	timeout time.Duration
}

// newDNSCrypt returns a new DNSCrypt Upstream for the server at addr.  Nothing
// is fetched here: the client and the resolver info are created on the first
// exchange.
func newDNSCrypt(addr *url.URL, opts *Options) (u *dnsCrypt) {
	u = &dnsCrypt{
		mu:         new(sync.RWMutex),
		addr:       addr,
		timeout:    opts.Timeout,
		verifyCert: opts.VerifyDNSCryptCertificate,
	}

	return u
}

// type check
var _ Upstream = (*dnsCrypt)(nil)

// Address implements the [Upstream] interface for *dnsCrypt.  It returns the
// string form of the DNSCrypt server URL.
func (p *dnsCrypt) Address() (addr string) {
	addr = p.addr.String()

	return addr
}

// Exchange implements the [Upstream] interface for *dnsCrypt.  If the first
// attempt fails with a deadline or an EOF error, the client and the resolver
// info are renewed and the request is sent once more.
func (p *dnsCrypt) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	resp, err = p.exchangeDNSCrypt(m)

	needsReset := errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF)
	if !needsReset {
		return resp, err
	}

	// If request times out, it is possible that the server configuration has
	// been changed.  It is safe to assume that the key was rotated, see
	// https://dnscrypt.pl/2017/02/26/how-key-rotation-is-automated.
	// Re-fetch the server certificate info for new requests to not fail.
	if _, _, err = p.resetClient(); err != nil {
		return nil, err
	}

	return p.exchangeDNSCrypt(m)
}

// Close implements the [Upstream] interface for *dnsCrypt.  It doesn't release
// anything and always returns nil.
func (p *dnsCrypt) Close() (err error) { return nil }

// exchangeDNSCrypt sends m to the DNSCrypt server and returns the response.
// The client and the resolver info are renewed first if either is missing or
// the resolver certificate has expired.  A truncated response is requested
// again over TCP.  A response whose ID differs from m's is returned along with
// [dns.ErrId].
func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) {
	p.mu.RLock()
	client, resolverInfo := p.client, p.resolverInfo
	p.mu.RUnlock()

	// The nil checks go first, so the certificate is only inspected when the
	// resolver info exists.
	//
	// TODO(ameshkov):  Consider using [time.Time] for [dnscrypt.Cert.NotAfter].
	needsReset := client == nil ||
		resolverInfo == nil ||
		resolverInfo.ResolverCert.NotAfter < uint32(time.Now().Unix())
	if needsReset {
		client, resolverInfo, err = p.resetClient()
		if err != nil {
			// Don't wrap the error, because it's informative enough as is.
			return nil, err
		}
	}

	resp, err = client.Exchange(m, resolverInfo)
	if resp != nil && resp.Truncated {
		log.Debug(
			"dnscrypt %s: received truncated, falling back to tcp with %s",
			p.addr,
			&m.Question[0],
		)

		tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: networkTCP}
		resp, err = tcpClient.Exchange(m, resolverInfo)
	}

	if err == nil && resp != nil && resp.Id != m.Id {
		err = dns.ErrId
	}

	return resp, err
}

// resetClient renews the DNSCrypt client and the resolver info.  It fetches
// the resolver certificate over UDP and, if p.verifyCert is set, verifies it.
// Whatever it returns is also stored in p.  On failure both values are stored
// as nil, so the next request triggers another renewal.
func (p *dnsCrypt) resetClient() (client *dnscrypt.Client, ri *dnscrypt.ResolverInfo, err error) {
	addr := p.Address()

	defer func() {
		p.mu.Lock()
		defer p.mu.Unlock()

		p.client, p.resolverInfo = client, ri
	}()

	// Use UDP for DNSCrypt upstreams by default.
	c := &dnscrypt.Client{Timeout: p.timeout, Net: networkUDP}

	info, dialErr := c.Dial(addr)
	if dialErr != nil {
		return nil, nil, fmt.Errorf("fetching certificate info from %s: %w", addr, dialErr)
	}

	if p.verifyCert != nil {
		if verifyErr := p.verifyCert(info.ResolverCert); verifyErr != nil {
			return nil, nil, fmt.Errorf("verifying certificate info from %s: %w", addr, verifyErr)
		}
	}

	return c, info, nil
}
07070100000077000081A4000000000000000000000001650C592100001850000000000000000000000000000000000000002A00000000dnsproxy-0.55.0/upstream/dnscrypt_test.gopackage upstream

import (
	"context"
	"net"
	"os"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/ameshkov/dnsstamps"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Helpers

// dnsCryptHandlerFunc is a function-based implementation of the
// [dnscrypt.Handler] interface.
type dnsCryptHandlerFunc func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error)

// type check
var _ dnscrypt.Handler = dnsCryptHandlerFunc(nil)

// ServeDNS implements the [dnscrypt.Handler] interface for dnsCryptHandlerFunc
// by calling the function itself.
func (f dnsCryptHandlerFunc) ServeDNS(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
	err = f(w, r)

	return err
}

// startTestDNSCryptServer starts a test DNSCrypt server with the specified
// resolver config and handler.  The server listens for both UDP and TCP on the
// same random localhost port, and it is shut down when the test finishes.  It
// returns the DNS stamp of the server.
func startTestDNSCryptServer(
	t testing.TB,
	rc dnscrypt.ResolverConfig,
	h dnscrypt.Handler,
) (stamp dnsstamps.ServerStamp) {
	t.Helper()

	cert, err := rc.CreateCert()
	require.NoError(t, err)

	s := &dnscrypt.Server{
		ProviderName: rc.ProviderName,
		ResolverCert: cert,
		Handler:      h,
	}
	testutil.CleanupAndRequireSuccess(t, func() (err error) {
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()

		return s.Shutdown(ctx)
	})

	localhost := netutil.IPv4Localhost().AsSlice()

	// Prepare TCP listener.
	tcpAddr := &net.TCPAddr{IP: localhost, Port: 0}
	tcpConn, err := net.ListenTCP("tcp", tcpAddr)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, tcpConn.Close)

	// Prepare UDP listener on the same port.
	port := testutil.RequireTypeAssert[*net.TCPAddr](t, tcpConn.Addr()).Port
	udpAddr := &net.UDPAddr{IP: localhost, Port: port}
	udpConn, err := net.ListenUDP("udp", udpAddr)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, udpConn.Close)

	// Start the server.
	go func() {
		udpErr := s.ServeUDP(udpConn)
		require.ErrorIs(testutil.PanicT{}, udpErr, net.ErrClosed)
	}()

	go func() {
		tcpErr := s.ServeTCP(tcpConn)
		require.NoError(testutil.PanicT{}, tcpErr)
	}()

	stamp, err = rc.CreateStamp(udpConn.LocalAddr().String())
	require.NoError(t, err)

	// Make sure the TCP listener accepts connections.  Close the probe
	// connection right away, since nothing else closes it and it would leak.
	probeConn, err := net.Dial("tcp", tcpConn.Addr().String())
	require.NoError(t, err)
	require.NoError(t, probeConn.Close())

	return stamp
}

// Tests

func TestUpstreamDNSCrypt(t *testing.T) {
	// address is the stamp of AdGuard DNS (DNSCrypt).
	const address = "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20"

	u, err := AddressToUpstream(address, &Options{Timeout: dialTimeout})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Test that it responds properly to several queries in a row.
	const queriesNum = 10
	for n := 0; n < queriesNum; n++ {
		checkUpstream(t, u, address)
	}
}

func TestDNSCrypt_Exchange_truncated(t *testing.T) {
	// Prepare the test DNSCrypt server config.
	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	var udpNum, tcpNum atomic.Uint32
	h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
		if w.RemoteAddr().Network() == networkUDP {
			udpNum.Add(1)
		} else {
			tcpNum.Add(1)
		}

		// Build a TXT answer that is big enough to get truncated over UDP,
		// which the assertions below rely on.
		veryLongString := strings.Repeat("VERY LONG STRING", 7)
		txt := make([]string, 0, 50)
		for n := 0; n < 50; n++ {
			txt = append(txt, veryLongString)
		}

		res := (&dns.Msg{}).SetReply(r)
		res.Answer = append(res.Answer, &dns.TXT{
			Hdr: dns.RR_Header{
				Name:   r.Question[0].Name,
				Rrtype: dns.TypeTXT,
				Ttl:    300,
				Class:  dns.ClassINET,
			},
			Txt: txt,
		})

		return w.WriteMsg(res)
	})
	srvStamp := startTestDNSCryptServer(t, rc, h)

	u, err := AddressToUpstream(srvStamp.String(), &Options{Timeout: timeout})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)

	// Check that the response is not truncated even though it's huge, i.e. the
	// upstream fell back from UDP to TCP exactly once.
	res, err := u.Exchange(req)
	require.NoError(t, err)

	assert.False(t, res.Truncated)
	assert.Equal(t, uint32(1), udpNum.Load())
	assert.Equal(t, uint32(1), tcpNum.Load())
}

func TestDNSCrypt_Exchange_deadline(t *testing.T) {
	// Prepare the test DNSCrypt server config.
	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	// The handler never writes a response, so the upstream has to time out.
	silent := dnsCryptHandlerFunc(func(_ dnscrypt.ResponseWriter, _ *dns.Msg) (err error) {
		return nil
	})
	srvStamp := startTestDNSCryptServer(t, rc, silent)

	// Use a shorter timeout to speed up the test.
	opts := &Options{Timeout: 100 * time.Millisecond}
	u, err := AddressToUpstream(srvStamp.String(), opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)

	res, err := u.Exchange(req)
	require.ErrorIs(t, err, os.ErrDeadlineExceeded)
	assert.Nil(t, res)
}

func TestDNSCrypt_Exchange_dialFail(t *testing.T) {
	// Prepare the test DNSCrypt server config.
	rc, err := dnscrypt.GenerateResolverConfig("example.org", nil)
	require.NoError(t, err)

	h := dnsCryptHandlerFunc(func(w dnscrypt.ResponseWriter, r *dns.Msg) (err error) {
		return nil
	})

	req := (&dns.Msg{}).SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)
	var u Upstream

	// Start a server and shut it down at the end of the subtest, so that the
	// next subtest dials an address nobody listens on.
	require.True(t, t.Run("run_and_shutdown", func(t *testing.T) {
		srvStamp := startTestDNSCryptServer(t, rc, h)

		// Use a shorter timeout to speed up the test.
		u, err = AddressToUpstream(srvStamp.String(), &Options{Timeout: 100 * time.Millisecond})
		require.NoError(t, err)
	}))

	require.True(t, t.Run("dial_fail", func(t *testing.T) {
		testutil.CleanupAndRequireSuccess(t, u.Close)

		var res *dns.Msg
		res, err = u.Exchange(req)
		require.Error(t, err)

		assert.Nil(t, res)
	}))

	t.Run("restart", func(t *testing.T) {
		const validationErr errors.Error = "bad cert"

		srvStamp := startTestDNSCryptServer(t, rc, h)

		// Use a shorter timeout to speed up the test.
		u, err = AddressToUpstream(srvStamp.String(), &Options{
			Timeout: 100 * time.Millisecond,
			VerifyDNSCryptCertificate: func(cert *dnscrypt.Cert) (err error) {
				return validationErr
			},
		})
		require.NoError(t, err)
		// Close this upstream too, like every other upstream in these tests.
		testutil.CleanupAndRequireSuccess(t, u.Close)

		var res *dns.Msg
		res, err = u.Exchange(req)
		require.ErrorIs(t, err, validationErr)

		assert.Nil(t, res)
	})
}
07070100000078000081A4000000000000000000000001650C592100005076000000000000000000000000000000000000002000000000dnsproxy-0.55.0/upstream/doh.gopackage upstream

import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"runtime"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"golang.org/x/net/http2"
)

// Values to configure HTTP and HTTP/2 transport.
const (
	// transportDefaultReadIdleTimeout is the default timeout for pinging
	// idle connections in HTTP/2 transport.
	transportDefaultReadIdleTimeout = 30 * time.Second

	// transportDefaultIdleConnTimeout is the default timeout for idle
	// connections in HTTP transport.
	transportDefaultIdleConnTimeout = 5 * time.Minute

	// dohMaxConnsPerHost controls the maximum number of connections for
	// each host.  Note, that setting it to 1 may cause issues with Go's http
	// implementation, see https://github.com/AdguardTeam/dnsproxy/issues/278.
	dohMaxConnsPerHost = 2

	// dohMaxIdleConns controls the maximum number of connections being idle
	// at the same time.
	dohMaxIdleConns = 2
)

// dnsOverHTTPS is a struct that implements the Upstream interface for the
// DNS-over-HTTPS protocol.  Its HTTP client is created lazily and may be
// re-created after failed requests.
type dnsOverHTTPS struct {
	// getDialer either returns an initialized dial handler or creates a new
	// one.
	getDialer DialerInitializer

	// addr is the DNS-over-HTTPS server URL.
	addr *url.URL

	// tlsConf is the configuration of TLS.  Transports are given a clone of
	// it rather than the value itself.
	tlsConf *tls.Config

	// client is the HTTP client used to send requests.  It is nil until it is
	// first created and after a failed re-creation.
	//
	// The Client's Transport typically has internal state (cached TCP
	// connections), so Clients should be reused instead of created as needed.
	// Clients are safe for concurrent use by multiple goroutines.
	client *http.Client

	// quicConfig is the QUIC configuration that is used if HTTP/3 is enabled
	// for this upstream.  When 0-RTT is rejected, it is replaced with a copy
	// that has a new token store rather than modified in place.
	quicConfig *quic.Config

	// clientMu protects client.
	clientMu sync.Mutex

	// quicConfMu protects quicConfig.
	quicConfMu sync.Mutex

	// timeout is used in HTTP client and for H3 probes.
	timeout time.Duration
}

// type check
var _ Upstream = (*dnsOverHTTPS)(nil)

// newDoH returns the DNS-over-HTTPS Upstream for addr.
//
// The "h3" scheme is rewritten to "https" and limits the upstream to HTTP/3.
// Otherwise, the HTTP versions come from opts, or from [DefaultHTTPVersions]
// if opts has none.  The HTTP versions become the NextProtos of the TLS
// configuration.  A finalizer closes the upstream if it is garbage-collected
// without Close being called.
func newDoH(addr *url.URL, opts *Options) (u Upstream, err error) {
	addPort(addr, defaultPortDoH)

	httpVersions := opts.HTTPVersions
	switch {
	case addr.Scheme == "h3":
		addr.Scheme = "https"
		httpVersions = []HTTPVersion{HTTPVersion3}
	case len(httpVersions) == 0:
		httpVersions = DefaultHTTPVersions
	default:
		// Go on.
	}

	getDialer, err := newDialerInitializer(addr, opts)
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return nil, err
	}

	var nextProtos []string
	for _, v := range httpVersions {
		nextProtos = append(nextProtos, string(v))
	}

	// #nosec G402 -- TLS certificate verification could be disabled by
	// configuration.
	tlsConf := &tls.Config{
		ServerName:   addr.Hostname(),
		RootCAs:      opts.RootCAs,
		CipherSuites: opts.CipherSuites,
		// Use the default capacity for the LRU cache.  It may be useful to
		// store several caches since the user may be routed to different
		// servers in case there's load balancing on the server-side.
		ClientSessionCache:    tls.NewLRUClientSessionCache(0),
		MinVersion:            tls.VersionTLS12,
		InsecureSkipVerify:    opts.InsecureSkipVerify,
		VerifyPeerCertificate: opts.VerifyServerCertificate,
		VerifyConnection:      opts.VerifyConnection,
		NextProtos:            nextProtos,
	}

	ups := &dnsOverHTTPS{
		getDialer: getDialer,
		addr:      addr,
		tlsConf:   tlsConf,
		quicConfig: &quic.Config{
			KeepAlivePeriod: QUICKeepAlivePeriod,
			TokenStore:      newQUICTokenStore(),
			Tracer:          opts.QUICTracer,
		},
		timeout: opts.Timeout,
	}
	runtime.SetFinalizer(ups, (*dnsOverHTTPS).Close)

	return ups, nil
}

// Address implements the [Upstream] interface for *dnsOverHTTPS.  It returns
// the string form of the DoH server URL.
func (p *dnsOverHTTPS) Address() (addr string) {
	addr = p.addr.String()

	return addr
}

// Exchange implements the Upstream interface for *dnsOverHTTPS.
//
// The request is sent with ID 0.  The original ID is restored in both m and
// the response before returning.  If the HTTP client was already cached, up to
// two more attempts are made with a re-created client while the error looks
// recoverable.  If the request still fails, the client is reset once more.
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	// In order to maximize HTTP cache friendliness, DoH clients using media
	// formats that include the ID field from the DNS message header, such as
	// "application/dns-message", SHOULD use a DNS ID of 0 in every DNS
	// request.
	//
	// See https://www.rfc-editor.org/rfc/rfc8484.html.
	origID := m.Id
	m.Id = 0
	defer func() {
		// Restore the original ID to not break compatibility with proxies.
		m.Id = origID
		if resp != nil {
			resp.Id = origID
		}
	}()

	// Re-connecting is only attempted if there was already an active client
	// before sending the request.
	client, isCached, err := p.getClient()
	if err != nil {
		return nil, fmt.Errorf("failed to init http client: %w", err)
	}

	// Make the first attempt to send the DNS query.
	resp, err = p.exchangeHTTPS(client, m)

	// Make up to 2 attempts to re-create the HTTP client and send the request
	// again.  There are several cases (mostly, with QUIC) where this workaround
	// is necessary to make HTTP client usable.  We need to make 2 attempts in
	// the case when the connection was closed (due to inactivity for example)
	// AND the server refuses to open a 0-RTT connection.
	if isCached {
		for attempt := 0; p.shouldRetry(err) && attempt < 2; attempt++ {
			client, err = p.resetClient(err)
			if err != nil {
				return nil, fmt.Errorf("failed to reset http client: %w", err)
			}

			resp, err = p.exchangeHTTPS(client, m)
		}
	}

	if err == nil {
		return resp, nil
	}

	// If the request failed anyway, make sure this client isn't used again.
	_, resErr := p.resetClient(err)

	return nil, errors.WithDeferred(err, resErr)
}

// Close implements the Upstream interface for *dnsOverHTTPS.  It removes the
// finalizer from p and closes the current HTTP client, if there is one.
func (p *dnsOverHTTPS) Close() (err error) {
	p.clientMu.Lock()
	defer p.clientMu.Unlock()

	runtime.SetFinalizer(p, nil)

	if client := p.client; client != nil {
		return p.closeClient(client)
	}

	return nil
}

// closeClient cleans up resources used by client if necessary.  At this point
// only HTTP/3 clients need it, since their keep-alive connections may leak.
// The transport of an HTTP/3 client is expected to implement [io.Closer].
func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
	if !isHTTP3(client) {
		return nil
	}

	closer := client.Transport.(io.Closer)

	return closer.Close()
}

// exchangeHTTPS calls exchangeHTTPSClient and logs the request and its result.
// HTTP/3 requests are logged with the UDP network and all others with TCP.
func (p *dnsOverHTTPS) exchangeHTTPS(client *http.Client, req *dns.Msg) (resp *dns.Msg, err error) {
	addr := p.Address()

	network := networkTCP
	if isHTTP3(client) {
		network = networkUDP
	}

	logBegin(addr, network, req)
	resp, err = p.exchangeHTTPSClient(client, req)
	logFinish(addr, network, err)

	return resp, err
}

// exchangeHTTPSClient sends the DNS query to a DoH resolver using the specified
// http.Client instance.
//
// The message is sent base64url-encoded in the "dns" query parameter of a GET
// request, see RFC 8484.  No valid DNS message can be larger than
// [dns.MaxMsgSize] bytes, so only that much of the response body is accepted,
// and a larger body is reported as an error instead of being read into memory
// entirely.  If the response ID doesn't match the one of req, the response is
// returned along with [dns.ErrId].
func (p *dnsOverHTTPS) exchangeHTTPSClient(
	client *http.Client,
	req *dns.Msg,
) (resp *dns.Msg, err error) {
	buf, err := req.Pack()
	if err != nil {
		return nil, fmt.Errorf("packing message: %w", err)
	}

	// It appears, that GET requests are more memory-efficient with Golang
	// implementation of HTTP/2.
	method := http.MethodGet
	if isHTTP3(client) {
		// If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT.
		method = http3.MethodGet0RTT
	}

	u := url.URL{
		Scheme:   p.addr.Scheme,
		Host:     p.addr.Host,
		Path:     p.addr.Path,
		RawQuery: fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)),
	}

	httpReq, err := http.NewRequest(method, u.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("creating http request to %s: %w", p.addr, err)
	}

	httpReq.Header.Set("Accept", "application/dns-message")
	httpReq.Header.Set("User-Agent", "")

	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("requesting %s: %w", p.addr, err)
	}
	defer log.OnCloserError(httpResp.Body, log.DEBUG)

	// Read one byte more than the maximum message size to tell an oversized
	// body apart from one that is exactly at the limit.
	body, err := io.ReadAll(io.LimitReader(httpResp.Body, dns.MaxMsgSize+1))
	if err != nil {
		return nil, fmt.Errorf("reading %s: %w", p.addr, err)
	}

	if httpResp.StatusCode != http.StatusOK {
		return nil,
			fmt.Errorf(
				"expected status %d, got %d from %s",
				http.StatusOK,
				httpResp.StatusCode,
				p.addr,
			)
	}

	if len(body) > dns.MaxMsgSize {
		return nil, fmt.Errorf(
			"response from %s is larger than %d bytes",
			p.addr,
			dns.MaxMsgSize,
		)
	}

	resp = &dns.Msg{}
	err = resp.Unpack(body)
	if err != nil {
		return nil, fmt.Errorf(
			"unpacking response from %s: body is %s: %w",
			p.addr,
			body,
			err,
		)
	}

	if resp.Id != req.Id {
		err = dns.ErrId
	}

	return resp, err
}

// shouldRetry reports whether err is worth re-creating the HTTP client and
// retrying the request for.  Timeouts and the errors recognized by
// isQUICRetryError qualify; a nil error never does.
func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) {
	if err == nil {
		return false
	}

	// On a timeout error, try to forcibly re-create the HTTP client instance.
	// This is an attempt to fix an issue with DoH client stalling after a
	// network change.
	//
	// See https://github.com/AdguardTeam/AdGuardHome/issues/3217.
	var netErr net.Error
	isTimeout := errors.As(err, &netErr) && netErr.Timeout()

	return isTimeout || isQUICRetryError(err)
}

// resetClient closes the current *http.Client used by this upstream, if any,
// and creates a new one.  It takes the error that caused the reset, since a
// rejected 0-RTT also requires a fresh QUIC token store.  The new client is
// stored in p and returned; it is nil if creation failed.
func (p *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) {
	p.clientMu.Lock()
	defer p.clientMu.Unlock()

	// Reset the TokenStore only if 0-RTT was rejected.
	if errors.Is(resetErr, quic.Err0RTTRejected) {
		p.resetQUICConfig()
	}

	if old := p.client; old != nil {
		if closeErr := p.closeClient(old); closeErr != nil {
			log.Info("warning: failed to close the old http client: %v", closeErr)
		}
	}

	log.Debug("re-creating the http client due to %v", resetErr)

	client, err = p.createClient()
	p.client = client

	return client, err
}

// getQUICConfig returns the current QUIC config under quicConfMu.  The result
// is shared, so callers must not change its properties.
func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) {
	p.quicConfMu.Lock()
	c = p.quicConfig
	p.quicConfMu.Unlock()

	return c
}

// resetQUICConfig replaces the QUIC config with a copy that has a new token
// store, so that stale tokens aren't used for 0-RTT.  The previous config is
// left untouched, since getQUICConfig may already have handed it out.
func (p *dnsOverHTTPS) resetQUICConfig() {
	p.quicConfMu.Lock()
	defer p.quicConfMu.Unlock()

	conf := p.quicConfig.Clone()
	conf.TokenStore = newQUICTokenStore()
	p.quicConfig = conf
}

// getClient returns the HTTP client of this DoH resolver, creating it along
// with its transport if there is none yet.  isCached is true if the client
// already existed.  An error is returned if waiting for clientMu alone took
// longer than the upstream's timeout.
func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) {
	startTime := time.Now()

	p.clientMu.Lock()
	defer p.clientMu.Unlock()

	if c = p.client; c != nil {
		return c, true, nil
	}

	// Timeout can be exceeded while waiting for the lock.  This happens quite
	// often on mobile devices.
	if elapsed := time.Since(startTime); p.timeout > 0 && elapsed > p.timeout {
		return nil, false, fmt.Errorf("timeout exceeded: %s", elapsed)
	}

	log.Debug("creating a new http client")
	p.client, err = p.createClient()

	return p.client, false, err
}

// createClient creates a new *http.Client instance, stores it in p.client, and
// returns it.  Both visible callers hold clientMu while calling it.  The HTTP
// version depends on whether HTTP/3 is allowed and provided by this upstream:
// creating the transport attempts a QUIC connection to check for HTTP/3
// support.
func (p *dnsOverHTTPS) createClient() (*http.Client, error) {
	transport, err := p.createTransport()
	if err != nil {
		return nil, fmt.Errorf("initializing http transport: %w", err)
	}

	// TODO(ameshkov):  p.timeout may appear zero that will disable the
	// timeout for client, consider using the default.
	p.client = &http.Client{
		Transport: transport,
		Timeout:   p.timeout,
		Jar:       nil,
	}

	return p.client, nil
}

// createTransport initializes an HTTP transport for this DoH resolver that
// sends requests exactly to the IP address got from the bootstrap resolver.
//
// HTTP/3 is tried first, if it is enabled in the upstream options, and its
// transport is returned when the QUIC attempt succeeds.  Otherwise, an H1/H2
// transport is returned, or an error if HTTP/1.1 and HTTP/2 aren't allowed.
func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
	dialContext, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("bootstrapping %s: %w", p.addr, err)
	}

	tlsConf := p.tlsConf.Clone()

	// If the probe QUIC connection is established successfully, HTTP/3 is
	// used for this upstream.
	transportH3, h3Err := p.createTransportH3(tlsConf, dialContext)
	if h3Err == nil {
		log.Debug("using HTTP/3 for this upstream: QUIC was faster")

		return transportH3, nil
	}

	log.Debug("using HTTP/2 for this upstream: %v", h3Err)

	if !p.supportsHTTP() {
		return nil, errors.Error("HTTP1/1 and HTTP2 are not supported by this upstream")
	}

	transport := &http.Transport{
		TLSClientConfig:    tlsConf,
		DisableCompression: true,
		DialContext:        dialContext,
		IdleConnTimeout:    transportDefaultIdleConnTimeout,
		MaxConnsPerHost:    dohMaxConnsPerHost,
		MaxIdleConns:       dohMaxIdleConns,
		// Since we have a custom DialContext, we need to use this field to make
		// golang http.Client attempt to use HTTP/2.  Otherwise, it would only
		// be used when negotiated on the TLS level.
		ForceAttemptHTTP2: true,
	}

	// Explicitly configure transport to use HTTP/2.
	//
	// See https://github.com/AdguardTeam/dnsproxy/issues/11.
	transportH2, err := http2.ConfigureTransports(transport)
	if err != nil {
		return nil, err
	}

	// Enable HTTP/2 pings on idle connections.
	transportH2.ReadIdleTimeout = transportDefaultReadIdleTimeout

	return transport, nil
}

// http3Transport is a wrapper over *http3.RoundTripper that tries to optimize
// its behavior.  The main thing that it does is trying to force use a single
// connection to a host instead of creating a new one all the time.  It also
// helps mitigate race issues with quic-go.
type http3Transport struct {
	// baseTransport is the underlying HTTP/3 round tripper.
	baseTransport *http3.RoundTripper

	// closed is set by Close, after which RoundTrip returns net.ErrClosed.
	closed bool
	// mu protects closed.  RoundTrip holds it for reading for the whole round
	// trip, so Close waits for in-flight round trips to finish.
	mu     sync.RWMutex
}

// type check
var _ http.RoundTripper = (*http3Transport)(nil)

// RoundTrip implements the [http.RoundTripper] interface for *http3Transport.
// It first tries a cached connection to the target host and only falls back to
// a regular round trip, which may create a new connection, when there is none.
// After Close, it returns [net.ErrClosed].
func (h *http3Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if h.closed {
		return nil, net.ErrClosed
	}

	cachedOnly := http3.RoundTripOpt{OnlyCachedConn: true}
	resp, err = h.baseTransport.RoundTripOpt(req, cachedOnly)
	if !errors.Is(err, http3.ErrNoCachedConn) {
		return resp, err
	}

	// There is no cached connection, so let the base transport create one.
	return h.baseTransport.RoundTrip(req)
}

// type check
var _ io.Closer = (*http3Transport)(nil)

// Close implements the [io.Closer] interface for *http3Transport.  It marks the
// transport as closed, so that later round trips fail, and closes the base
// transport.
func (h *http3Transport) Close() (err error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.closed = true
	err = h.baseTransport.Close()

	return err
}

// createTransportH3 tries to create an HTTP/3 transport for this upstream.
//
// Falling back to H1/H2 must be possible in case HTTP/3 is unavailable or too
// slow, so this method first runs probeH3.  Only if the probe succeeds is an
// *http3.RoundTripper created.  That round tripper always dials the address
// the probe returned.
func (p *dnsOverHTTPS) createTransportH3(
	tlsConfig *tls.Config,
	dialContext bootstrap.DialHandler,
) (roundTripper http.RoundTripper, err error) {
	if !p.supportsH3() {
		return nil, errors.Error("HTTP3 support is not enabled")
	}

	addr, err := p.probeH3(tlsConfig, dialContext)
	if err != nil {
		return nil, err
	}

	// dialBootstrapped ignores the address passed by the round tripper and
	// always connects to the one that we got from the bootstrapper.
	dialBootstrapped := func(
		ctx context.Context,
		_ string,
		tlsCfg *tls.Config,
		cfg *quic.Config,
	) (quic.EarlyConnection, error) {
		return quic.DialAddrEarly(ctx, addr, tlsCfg, cfg)
	}

	rt := &http3.RoundTripper{
		Dial:               dialBootstrapped,
		DisableCompression: true,
		TLSClientConfig:    tlsConfig,
		QuicConfig:         p.getQUICConfig(),
	}

	return &http3Transport{baseTransport: rt}, nil
}

// probeH3 checks whether QUIC is faster than TLS for this upstream.  On
// success it returns the address that should be used for establishing QUIC
// connections.  It returns an error when the QUIC probe fails first or when
// the TLS probe succeeds first.
func (p *dnsOverHTTPS) probeH3(
	tlsConfig *tls.Config,
	dialContext bootstrap.DialHandler,
) (addr string, err error) {
	// Use the bootstrapped address instead of the one passed to the function.
	// Dialing UDP doesn't create an actual connection, but it helps to find
	// out which IP is actually reachable when there are both IPv4 and IPv6
	// addresses.
	rawConn, err := dialContext(context.Background(), "udp", "")
	if err != nil {
		return "", fmt.Errorf("failed to dial: %w", err)
	}

	// The connection itself is never used, only its remote address is.
	_ = rawConn.Close()

	udpConn, ok := rawConn.(*net.UDPConn)
	if !ok {
		return "", fmt.Errorf("not a UDP connection to %s", p.addr)
	}

	addr = udpConn.RemoteAddr().String()

	// Don't waste time on probing when HTTP/3 is the only option.
	if p.supportsH3() && !p.supportsHTTP() {
		return addr, nil
	}

	// Probe connections get their own *tls.Config with an empty session
	// cache.  Surprisingly, this is really important since otherwise it
	// invalidates the existing cache.
	//
	// TODO(ameshkov): figure out why the sessions cache invalidates here.
	//
	// The verification callbacks passed via the bootstrap options are removed
	// as well, so that probe connections don't cause side effects there.
	//
	// TODO(ameshkov): consider exposing, somehow mark that this is a probe.
	probeTLSCfg := tlsConfig.Clone()
	probeTLSCfg.ClientSessionCache = nil
	probeTLSCfg.VerifyPeerCertificate = nil
	probeTLSCfg.VerifyConnection = nil

	// Race both probes; whichever reports first decides the outcome.
	quicRes := make(chan error, 1)
	tlsRes := make(chan error, 1)
	go p.probeQUIC(addr, probeTLSCfg, quicRes)
	go p.probeTLS(dialContext, probeTLSCfg, tlsRes)

	select {
	case err = <-quicRes:
		if err != nil {
			// QUIC failed, so HTTP/3 is not preferred.
			return "", err
		}

		// QUIC was faster.
		return addr, nil
	case err = <-tlsRes:
		if err != nil {
			// TLS failed, so prefer QUIC without waiting for its result.
			log.Debug("probing TLS: %v", err)

			return addr, nil
		}

		return "", errors.Error("TLS was faster than QUIC, prefer it")
	}
}

// probeQUIC attempts to establish a QUIC connection to addr and reports the
// result to ch: nil on success or the wrapped error otherwise.  The attempt is
// limited by the upstream timeout, or by dialTimeout if that's not set.  It
// runs in parallel with probeTLS to find out which protocol is faster.
func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) {
	start := time.Now()

	timeout := p.timeout
	if timeout == 0 {
		timeout = dialTimeout
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	conn, err := quic.DialAddrEarly(ctx, addr, tlsConfig, p.getQUICConfig())
	if err != nil {
		ch <- fmt.Errorf("opening QUIC connection to %s: %w", p.addr, err)

		return
	}

	// Only the handshake matters here, so there is nothing useful to do with
	// the error of closing the connection.
	_ = conn.CloseWithError(QUICCodeNoError, "")

	ch <- nil

	log.Debug("elapsed on establishing a QUIC connection: %s", time.Since(start))
}

// probeTLS attempts to establish a TLS connection using dialContext and
// reports the result to ch: nil on success or the wrapped error otherwise.  It
// runs in parallel with probeQUIC to find out which protocol is faster.
func (p *dnsOverHTTPS) probeTLS(dialContext bootstrap.DialHandler, tlsConfig *tls.Config, ch chan error) {
	start := time.Now()

	conn, err := tlsDial(dialContext, tlsConfig)
	if err != nil {
		ch <- fmt.Errorf("opening TLS connection: %w", err)

		return
	}

	// Only the handshake matters here, so there is nothing useful to do with
	// the error of closing the connection.
	_ = conn.Close()

	ch <- nil

	log.Debug("elapsed on establishing a TLS connection: %s", time.Since(start))
}

// supportsH3 returns true if HTTP/3 is among the ALPN protocols configured in
// the TLS configuration of this upstream.
func (p *dnsOverHTTPS) supportsH3() (ok bool) {
	want := string(HTTPVersion3)
	for _, proto := range p.tlsConf.NextProtos {
		if ok = proto == want; ok {
			break
		}
	}

	return ok
}

// supportsHTTP returns true if either HTTP/1.1 or HTTP/2 is among the ALPN
// protocols configured in the TLS configuration of this upstream.
func (p *dnsOverHTTPS) supportsHTTP() (ok bool) {
	for _, proto := range p.tlsConf.NextProtos {
		switch proto {
		case string(HTTPVersion11), string(HTTPVersion2):
			return true
		}
	}

	return false
}

// isHTTP3 reports whether client uses *http3Transport as its transport.  A
// client with a nil transport is not considered an HTTP/3 one.
func isHTTP3(client *http.Client) (ok bool) {
	switch client.Transport.(type) {
	case *http3Transport:
		return true
	default:
		return false
	}
}
07070100000079000081A4000000000000000000000001650C5921000034C4000000000000000000000000000000000000002500000000dnsproxy-0.55.0/upstream/doh_test.gopackage upstream

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"strconv"
	"sync/atomic"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
	"github.com/stretchr/testify/require"
)

func TestUpstreamDoH(t *testing.T) {
	testCases := []struct {
		name             string
		expectedProtocol HTTPVersion
		httpVersions     []HTTPVersion
		delayHandshakeH3 time.Duration
		delayHandshakeH2 time.Duration
		http3Enabled     bool
	}{{
		name:             "http1.1_h2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion11, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "fallback_to_http2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "http3",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3},
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http3_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH2: time.Second,
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http2_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH3: time.Second,
		expectedProtocol: HTTPVersion2,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			srv := startDoHServer(t, testDoHServerOptions{
				http3Enabled:     tc.http3Enabled,
				delayHandshakeH2: tc.delayHandshakeH2,
				delayHandshakeH3: tc.delayHandshakeH3,
			})
			t.Cleanup(srv.Shutdown)

			// Create a DNS-over-HTTPS upstream.
			address := fmt.Sprintf("https://%s/dns-query", srv.addr)

			var lastState tls.ConnectionState
			opts := &Options{
				InsecureSkipVerify: true,
				HTTPVersions:       tc.httpVersions,
				VerifyConnection: func(state tls.ConnectionState) (err error) {
					if state.NegotiatedProtocol != string(tc.expectedProtocol) {
						return fmt.Errorf(
							"expected %s, got %s",
							tc.expectedProtocol,
							state.NegotiatedProtocol,
						)
					}
					lastState = state
					return nil
				},
			}
			u, err := AddressToUpstream(address, opts)
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			// Test that it responds properly.
			for i := 0; i < 10; i++ {
				checkUpstream(t, u, address)
			}

			doh := u.(*dnsOverHTTPS)

			// Trigger re-connection.  Replace the client under its lock and
			// close the old one properly, so that its transport (e.g. an open
			// QUIC connection) doesn't leak after the reference is dropped.
			func() {
				doh.clientMu.Lock()
				defer doh.clientMu.Unlock()

				require.NoError(t, doh.closeClient(doh.client))

				doh.client = nil
			}()

			// Force it to establish the connection again.
			checkUpstream(t, u, address)

			// Check that TLS session was resumed properly.
			require.True(t, lastState.DidResume)
		})
	}
}

func TestUpstreamDoH_raceReconnect(t *testing.T) {
	testCases := []struct {
		name             string
		expectedProtocol HTTPVersion
		httpVersions     []HTTPVersion
		delayHandshakeH3 time.Duration
		delayHandshakeH2 time.Duration
		http3Enabled     bool
	}{{
		name:             "http1.1_h2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion11, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "fallback_to_http2",
		http3Enabled:     false,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		expectedProtocol: HTTPVersion2,
	}, {
		name:             "http3",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3},
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http3_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH2: time.Second,
		expectedProtocol: HTTPVersion3,
	}, {
		name:             "race_http2_faster",
		http3Enabled:     true,
		httpVersions:     []HTTPVersion{HTTPVersion3, HTTPVersion2},
		delayHandshakeH3: time.Second,
		expectedProtocol: HTTPVersion2,
	}}

	// These cases are supposed to be run with -race.  Unlike the ones in
	// TestUpstreamDoH, the HTTP handler here sleeps on every tenth request for
	// twice the upstream timeout.  This triggers the re-connection of the HTTP
	// client, which is important to test for race conditions.
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			const timeout = 100 * time.Millisecond

			var reqNum atomic.Int32
			dohHandler := createDoHHandlerFunc()

			mux := http.NewServeMux()
			mux.HandleFunc("/dns-query", func(w http.ResponseWriter, r *http.Request) {
				if reqNum.Add(1)%10 == 0 {
					time.Sleep(2 * timeout)
				}

				dohHandler(w, r)
			})

			srv := startDoHServer(t, testDoHServerOptions{
				handler:          mux,
				http3Enabled:     tc.http3Enabled,
				delayHandshakeH2: tc.delayHandshakeH2,
				delayHandshakeH3: tc.delayHandshakeH3,
			})
			t.Cleanup(srv.Shutdown)

			// The upstream used for the race test.
			address := fmt.Sprintf("https://%s/dns-query", srv.addr)
			u, err := AddressToUpstream(address, &Options{
				InsecureSkipVerify: true,
				HTTPVersions:       tc.httpVersions,
				Timeout:            timeout,
			})
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkRaceCondition(u)
		})
	}
}

func TestUpstreamDoH_serverRestart(t *testing.T) {
	testCases := []struct {
		name         string
		httpVersions []HTTPVersion
	}{{
		name:         "http2",
		httpVersions: []HTTPVersion{HTTPVersion11, HTTPVersion2},
	}, {
		name:         "http3",
		httpVersions: []HTTPVersion{HTTPVersion3},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Run the first server instance.
			srv := startDoHServer(t, testDoHServerOptions{http3Enabled: true})

			address := fmt.Sprintf("https://%s/dns-query", srv.addr)
			u, err := AddressToUpstream(address, &Options{
				InsecureSkipVerify: true,
				HTTPVersions:       tc.httpVersions,
				Timeout:            time.Second,
			})
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkUpstream(t, u, address)

			// Remember the port to restart the server on the same address.
			_, portStr, err := net.SplitHostPort(srv.addr)
			require.NoError(t, err)

			port, err := strconv.Atoi(portStr)
			require.NoError(t, err)

			// restart starts a new server instance on the remembered port.
			restart := func() (s *testDoHServer) {
				return startDoHServer(t, testDoHServerOptions{
					http3Enabled: true,
					port:         port,
				})
			}

			srv.Shutdown()
			srv = restart()

			// The upstream must work with the restarted server.
			checkUpstream(t, u, address)

			// While the server is down, exchanges must fail.
			srv.Shutdown()
			_, err = u.Exchange(createTestMessage())
			require.Error(t, err)

			srv = restart()
			defer srv.Shutdown()

			// The upstream must recover after the second restart as well.
			checkUpstream(t, u, address)
		})
	}
}

func TestUpstreamDoH_0RTT(t *testing.T) {
	srv := startDoHServer(t, testDoHServerOptions{http3Enabled: true})
	t.Cleanup(srv.Shutdown)

	// The tracer records every QUIC connection the upstream establishes.
	tracer := &quicTracer{}
	address := fmt.Sprintf("h3://%s/dns-query", srv.addr)
	u, err := AddressToUpstream(address, &Options{
		InsecureSkipVerify: true,
		QUICTracer:         tracer.TracerForConnection,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	uh := u.(*dnsOverHTTPS)
	req := createTestMessage()

	// exchange sends req through the upstream and checks the response.
	exchange := func() {
		resp, exchErr := uh.Exchange(req)
		require.NoError(t, exchErr)
		requireResponse(t, req, resp)
	}

	// The first exchange establishes the connection to the DoH3 server.
	exchange()

	// Close the active connection to make sure the upstream reconnects.
	func() {
		uh.clientMu.Lock()
		defer uh.clientMu.Unlock()

		err = uh.closeClient(uh.client)
		require.NoError(t, err)

		uh.client = nil
	}()

	// The second exchange establishes a new connection.
	exchange()

	conns := tracer.getConnectionsInfo()
	require.Len(t, conns, 2)

	// Only the second connection is expected to use 0-RTT.
	require.False(t, conns[0].is0RTT())
	require.True(t, conns[1].is0RTT())
}

// testDoHServerOptions allows customizing testDoHServer behavior.
type testDoHServerOptions struct {
	// handler serves the requests of both the H1/H2 and the H3 servers.  If
	// it's nil, the handler returned by createDoHHandler is used.
	handler http.Handler

	// delayHandshakeH2 is how long the H1/H2 server sleeps while handling each
	// ClientHello.  Zero means no delay.
	delayHandshakeH2 time.Duration

	// delayHandshakeH3 is the same as delayHandshakeH2, but for the H3 server.
	delayHandshakeH3 time.Duration

	// port is the port to listen on.  Zero means a random port.
	port int

	// http3Enabled tells whether an HTTP/3 server should be started as well.
	http3Enabled bool
}

// testDoHServer is an instance of a test DNS-over-HTTPS server.
type testDoHServer struct {
	// tlsConfig is the TLS configuration that is used for this server.  The
	// H1/H2 and H3 listeners use clones of it with their own NextProtos.
	tlsConfig *tls.Config

	// rootCAs is the pool with root certificates used by the test server.
	rootCAs *x509.CertPool

	// server is an HTTP/1.1 and HTTP/2 server.
	server *http.Server

	// serverH3 is an HTTP/3 server.  It's nil unless HTTP/3 is enabled.
	serverH3 *http3.Server

	// listenerH3 is the QUIC listener that's used to serve HTTP/3.  It's nil
	// unless HTTP/3 is enabled.
	listenerH3 *quic.EarlyListener

	// addr is the TCP address that this server listens to.  The H3 server, if
	// any, listens on the same port over UDP.
	addr string
}

// Shutdown stops the DoH server: the H1/H2 server and, if it was started, the
// H3 server along with its QUIC listener.  Errors are ignored.
func (s *testDoHServer) Shutdown() {
	if srv := s.server; srv != nil {
		_ = srv.Shutdown(context.Background())
	}

	if s.serverH3 == nil {
		return
	}

	_ = s.serverH3.Close()
	_ = s.listenerH3.Close()
}

// startDoHServer starts a new DNS-over-HTTPS server on the port from opts, or
// on a random one if it's zero, and returns the instance of this server.  The
// H1/H2 server is always started.  An HTTP/3 server on the same port is
// started only if opts.http3Enabled is true.
func startDoHServer(
	t *testing.T,
	opts testDoHServerOptions,
) (s *testDoHServer) {
	// Report failures at the caller's line, consistently with
	// startDoTServer.
	t.Helper()

	tlsConfig, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	handler := opts.handler
	if handler == nil {
		handler = createDoHHandler()
	}

	// Step one is to create a regular HTTP server, we'll always have it
	// running.
	server := &http.Server{
		Handler:     handler,
		ReadTimeout: time.Second,
	}

	// Listen TCP first.
	listenAddr := fmt.Sprintf("127.0.0.1:%d", opts.port)
	tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
	require.NoError(t, err)

	tcpListen, err := net.ListenTCP("tcp", tcpAddr)
	require.NoError(t, err)

	tlsConfigH2 := tlsConfig.Clone()
	tlsConfigH2.NextProtos = []string{string(HTTPVersion2), string(HTTPVersion11)}
	tlsConfigH2.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
		// Slow down the handshake to let the other protocol win the race.
		if opts.delayHandshakeH2 > 0 {
			time.Sleep(opts.delayHandshakeH2)
		}

		return nil, nil
	}
	tlsListen := tls.NewListener(tcpListen, tlsConfigH2)

	// Run the H1/H2 server.
	go func() {
		// TODO(ameshkov): check the error here.
		_ = server.Serve(tlsListen)
	}()

	// Get the real address that the listener now listens to.
	tcpAddr = tcpListen.Addr().(*net.TCPAddr)

	var serverH3 *http3.Server
	var listenerH3 *quic.EarlyListener

	if opts.http3Enabled {
		tlsConfigH3 := tlsConfig.Clone()
		tlsConfigH3.NextProtos = []string{string(HTTPVersion3)}
		tlsConfigH3.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
			// Slow down the handshake to let the other protocol win the race.
			if opts.delayHandshakeH3 > 0 {
				time.Sleep(opts.delayHandshakeH3)
			}

			return nil, nil
		}

		serverH3 = &http3.Server{
			Handler: handler,
		}

		// Listen UDP for the H3 server.  Reuse the same port as was used for
		// the TCP listener.
		udpAddr, uErr := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port))
		require.NoError(t, uErr)

		// QUIC configuration with the 0-RTT support enabled by default.
		quicConfig := &quic.Config{
			RequireAddressValidation: func(net.Addr) (ok bool) {
				return true
			},
			Allow0RTT: true,
		}
		listenerH3, err = quic.ListenAddrEarly(udpAddr.String(), tlsConfigH3, quicConfig)
		require.NoError(t, err)

		// Run the H3 server.
		go func() {
			// TODO(ameshkov): check the error here.
			_ = serverH3.ServeListener(listenerH3)
		}()
	}

	return &testDoHServer{
		tlsConfig:  tlsConfig,
		rootCAs:    rootCAs,
		server:     server,
		serverH3:   serverH3,
		listenerH3: listenerH3,
		// Save the address that the server listens to.
		addr: tcpAddr.String(),
	}
}

// createDoHHandlerFunc creates a simple http.HandlerFunc that decodes the DNS
// message from the base64url-encoded "dns" query parameter and responds with
// the test response.  Any decoding or encoding failure results in a 500
// response.
func createDoHHandlerFunc() (f http.HandlerFunc) {
	return func(w http.ResponseWriter, r *http.Request) {
		// replyErr responds with the internal server error describing err.
		replyErr := func(err error) {
			http.Error(
				w,
				fmt.Sprintf("internal error: %s", err),
				http.StatusInternalServerError,
			)
		}

		wire, err := base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
		if err != nil {
			replyErr(err)

			return
		}

		req := &dns.Msg{}
		if err = req.Unpack(wire); err != nil {
			replyErr(err)

			return
		}

		wire, err = respondToTestMessage(req).Pack()
		if err != nil {
			replyErr(err)

			return
		}

		w.Header().Set("Content-Type", "application/dns-message")

		if _, err = w.Write(wire); err != nil {
			panic(fmt.Errorf("unexpected error on writing response: %w", err))
		}
	}
}

// createDoHHandler returns an http.Handler that serves the handler from
// createDoHHandlerFunc on the "/dns-query" path.
func createDoHHandler() (h http.Handler) {
	mux := http.NewServeMux()
	mux.Handle("/dns-query", createDoHHandlerFunc())

	return mux
}
0707010000007A000081A4000000000000000000000001650C592100001B8C000000000000000000000000000000000000002000000000dnsproxy-0.55.0/upstream/dot.gopackage upstream

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/url"
	"os"
	"runtime"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
)

// dialTimeout is the global timeout for establishing a TLS connection.
// TODO(ameshkov): use bootstrap timeout instead.
const dialTimeout = 10 * time.Second

// dnsOverTLS implements the [Upstream] interface for the DNS-over-TLS protocol.
type dnsOverTLS struct {
	// addr is the DNS-over-TLS server URL.
	addr *url.URL

	// getDialer either returns an initialized dial handler or creates a
	// new one.
	getDialer DialerInitializer

	// tlsConf is the configuration of TLS.  A clone of it is used for every
	// newly dialed connection.
	tlsConf *tls.Config

	// connsMu protects conns.
	connsMu *sync.Mutex

	// conns stores the connections ready for reuse.  Don't use [sync.Pool]
	// here, since there is no need to deallocate these connections.
	//
	// TODO(e.burkov, ameshkov):  Currently connections just stored in FILO
	// order, which eventually makes most of them unusable due to timeouts.
	// This leads to weak performance for all exchanges coming across such
	// connections.
	conns []net.Conn
}

// type check
var _ Upstream = (*dnsOverTLS)(nil)

// newDoT returns the DNS-over-TLS Upstream for addr.  The default DoT port is
// passed to addPort for addr before anything else is done.  A finalizer that
// calls Close is set on the returned upstream.
func newDoT(addr *url.URL, opts *Options) (ups Upstream, err error) {
	addPort(addr, defaultPortDoT)

	getDialer, err := newDialerInitializer(addr, opts)
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return nil, err
	}

	// #nosec G402 -- TLS certificate verification could be disabled by
	// configuration.
	conf := &tls.Config{
		ServerName:   addr.Hostname(),
		RootCAs:      opts.RootCAs,
		CipherSuites: opts.CipherSuites,
		// Use the default capacity for the LRU cache.  It may be useful to
		// store several caches since the user may be routed to different
		// servers in case there's load balancing on the server-side.
		ClientSessionCache:    tls.NewLRUClientSessionCache(0),
		MinVersion:            tls.VersionTLS12,
		InsecureSkipVerify:    opts.InsecureSkipVerify,
		VerifyPeerCertificate: opts.VerifyServerCertificate,
		VerifyConnection:      opts.VerifyConnection,
	}

	dot := &dnsOverTLS{
		addr:      addr,
		getDialer: getDialer,
		tlsConf:   conf,
		connsMu:   &sync.Mutex{},
	}

	runtime.SetFinalizer(dot, (*dnsOverTLS).Close)

	return dot, nil
}

// Address implements the [Upstream] interface for *dnsOverTLS.  It returns the
// string form of the server URL.
func (p *dnsOverTLS) Address() string {
	u := p.addr

	return u.String()
}

// Exchange implements the [Upstream] interface for *dnsOverTLS.  It first
// tries a pooled (or newly dialed) connection.  If that fails, it closes the
// connection and retries once over a freshly dialed one.  A connection that
// completes the exchange successfully is returned to the pool.
func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
	dial, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
	}

	pooled, err := p.conn(dial)
	if err != nil {
		return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
	}

	reply, err = p.exchangeWithConn(pooled, m)
	if err == nil {
		p.putBack(pooled)

		return reply, nil
	}

	// The pooled connection might have been closed already, see
	// https://github.com/AdguardTeam/dnsproxy/issues/3.  The next connection
	// from the pool may be malformed as well, so dial a new one.
	err = errors.WithDeferred(err, pooled.Close())
	log.Debug("dot %s: bad conn from pool: %s", p.addr, err)

	fresh, err := tlsDial(dial, p.tlsConf.Clone())
	if err != nil {
		return nil, fmt.Errorf(
			"dialing %s: connecting to %s: %w",
			p.addr,
			p.tlsConf.ServerName,
			err,
		)
	}

	reply, err = p.exchangeWithConn(fresh, m)
	if err != nil {
		return reply, errors.WithDeferred(err, fresh.Close())
	}

	p.putBack(fresh)

	return reply, nil
}

// Close implements the [Upstream] interface for *dnsOverTLS.  It removes the
// finalizer, closes all the pooled connections, and empties the pool, so that
// the closed connections are neither handed out nor closed again later.
// Errors that are expected when closing a TCP connection are ignored.
func (p *dnsOverTLS) Close() (err error) {
	runtime.SetFinalizer(p, nil)

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	var closeErrs []error
	for _, conn := range p.conns {
		closeErr := conn.Close()
		if closeErr != nil && isCriticalTCP(closeErr) {
			closeErrs = append(closeErrs, closeErr)
		}
	}

	// Drop the references to the closed connections.
	p.conns = nil

	if len(closeErrs) > 0 {
		return errors.List("closing tls conns", closeErrs...)
	}

	return nil
}

// conn returns the most recently pooled connection if there is any and its
// deadline can be refreshed, or dials a new one otherwise.  A pooled
// connection whose deadline can't be refreshed is closed and discarded.
func (p *dnsOverTLS) conn(h bootstrap.DialHandler) (conn net.Conn, err error) {
	// stale is the connection taken from the pool that turned out unusable.
	var stale net.Conn

	// Close the stale connection and dial a new one, if needed, outside the
	// lock, since dialing involves network I/O.  This deferred call runs
	// after the unlock, because deferred calls run in reverse order.
	defer func() {
		if stale != nil {
			_ = stale.Close()
		}

		if conn == nil {
			conn, err = tlsDial(h, p.tlsConf.Clone())
			err = errors.Annotate(err, "connecting to %s: %w", p.tlsConf.ServerName)
		}
	}()

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	l := len(p.conns)
	if l == 0 {
		return nil, nil
	}

	p.conns, conn = p.conns[:l-1], p.conns[l-1]

	err = conn.SetDeadline(time.Now().Add(dialTimeout))
	if err != nil {
		log.Debug("dot upstream: setting deadline to conn from pool: %s", err)

		// If the deadline can't be updated, the connection was most likely
		// closed already.  Close it anyway so that it can't leak.
		stale = conn

		return nil, nil
	}

	log.Debug("dot upstream: using existing conn %s", conn.RemoteAddr())

	return conn, nil
}

// putBack returns conn to the pool, making it the next one handed out by
// conn.
func (p *dnsOverTLS) putBack(conn net.Conn) {
	p.connsMu.Lock()
	p.conns = append(p.conns, conn)
	p.connsMu.Unlock()
}

// exchangeWithConn writes m to conn and reads the reply from it.  The reply is
// returned together with dns.ErrId if its ID doesn't match the one of m.
func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) {
	addr := p.Address()

	logBegin(addr, networkTCP, m)
	defer func() { logFinish(addr, networkTCP, err) }()

	dnsConn := dns.Conn{Conn: conn}

	if err = dnsConn.WriteMsg(m); err != nil {
		return nil, fmt.Errorf("sending request to %s: %w", addr, err)
	}

	if reply, err = dnsConn.ReadMsg(); err != nil {
		return nil, fmt.Errorf("reading response from %s: %w", addr, err)
	}

	if reply.Id != m.Id {
		return reply, dns.ErrId
	}

	return reply, nil
}

// tlsDial is basically the same as tls.DialWithDialer, but it calls the given
// dialContext to get the underlying connection.  The returned connection
// still has the dialTimeout deadline set.
func tlsDial(dialContext bootstrap.DialHandler, conf *tls.Config) (c *tls.Conn, err error) {
	// The address is left empty since dialContext uses the bootstrapped one.
	rawConn, err := dialContext(context.Background(), networkTCP, "")
	if err != nil {
		return nil, err
	}

	c = tls.Client(rawConn, conf)

	// The deadline is set once the underlying connection is established, so
	// it limits the TLS handshake and everything after it until it's reset.
	if err = c.SetDeadline(time.Now().Add(dialTimeout)); err != nil {
		// Must not happen in normal circumstances.
		panic(fmt.Errorf("dnsproxy: tls dial: setting deadline: %w", err))
	}

	if err = c.Handshake(); err != nil {
		return nil, errors.WithDeferred(err, c.Close())
	}

	return c, nil
}

// isCriticalTCP returns true if err isn't an expected error in terms of closing
// the TCP connection.  Timeouts, EOF, already closed connections, exceeded
// deadlines, and errors recognized by isConnBroken are all expected.
func isCriticalTCP(err error) (ok bool) {
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return false
	}

	expected := errors.Is(err, io.EOF) ||
		errors.Is(err, net.ErrClosed) ||
		errors.Is(err, os.ErrDeadlineExceeded) ||
		isConnBroken(err)

	return !expected
}
0707010000007B000081A4000000000000000000000001650C592100001D2D000000000000000000000000000000000000002500000000dnsproxy-0.55.0/upstream/dot_test.gopackage upstream

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"net"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestUpstream_dnsOverTLS(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})

	// The DoT upstream under test.
	addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{InsecureSkipVerify: true})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Query several times, so that the pooled connection gets reused.
	for n := 10; n > 0; n-- {
		checkUpstream(t, u, addr)
	}
}

func TestUpstream_dnsOverTLS_race(t *testing.T) {
	const count = 10

	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})

	// The DoT upstream under test.
	addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{InsecureSkipVerify: true})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Use the upstream from several goroutines at once.
	var wg sync.WaitGroup
	wg.Add(count)
	for i := 0; i < count; i++ {
		go func() {
			defer wg.Done()

			pt := testutil.PanicT{}
			req := createTestMessage()
			resp, uErr := u.Exchange(req)
			require.NoError(pt, uErr)
			requireResponse(pt, req, resp)
		}()
	}

	wg.Wait()
}

// TestUpstream_dnsOverTLS_poolReconnect checks that a pooled connection closed
// behind the upstream's back is replaced with a new one, and that the TLS
// session is resumed for that new connection.
func TestUpstream_dnsOverTLS_poolReconnect(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})

	// lastState is the state of the most recently verified TLS connection.
	// It's used to check the session resumption.
	var lastState tls.ConnectionState

	addr := (&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String()
	u, err := AddressToUpstream(addr, &Options{
		InsecureSkipVerify: true,
		VerifyConnection: func(state tls.ConnectionState) error {
			lastState = state

			return nil
		},
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	p := testutil.RequireTypeAssert[*dnsOverTLS](t, u)

	// exchange sends a new test message and checks the reply.
	exchange := func() {
		req := createTestMessage()
		reply, exchErr := u.Exchange(req)
		require.NoError(t, exchErr)
		requireResponse(t, req, reply)
	}

	exchange()

	// Close the only pooled connection.
	require.Len(t, p.conns, 1)
	closed := p.conns[0]
	require.NoError(t, closed.Close())

	exchange()

	// The broken connection must have been replaced, not kept.
	require.Len(t, p.conns, 1)
	assert.NotSame(t, closed, p.conns[0])

	// The session must have been resumed on the last attempt.
	assert.True(t, lastState.DidResume)
}

func TestUpstream_dnsOverTLS_poolDeadline(t *testing.T) {
	srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(req)))
	})

	// The DoT upstream under test.
	addr := (&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String()
	u, err := AddressToUpstream(addr, &Options{InsecureSkipVerify: true})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// The first exchange puts a connection into the pool.
	req := createTestMessage()
	resp, err := u.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	p := testutil.RequireTypeAssert[*dnsOverTLS](t, u)

	// Take the connection from the pool and use it again.
	require.Len(t, p.conns, 1)
	pooled := p.conns[0]

	dialHandler, err := p.getDialer()
	require.NoError(t, err)

	got, err := p.conn(dialHandler)
	require.NoError(t, err)
	require.Same(t, got, pooled)

	resp, err = p.exchangeWithConn(pooled, req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	// Extend the deadline and return the connection to the pool.
	require.NoError(t, pooled.SetDeadline(time.Now().Add(10*time.Hour)))
	p.putBack(pooled)

	// Take it from the pool once more and reuse it.
	require.Len(t, p.conns, 1)
	pooled = p.conns[0]

	got, err = p.conn(dialHandler)
	require.NoError(t, err)
	require.Same(t, got, pooled)

	resp, err = p.exchangeWithConn(got, req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	// A connection with the deadline in the past can't be used.
	require.NoError(t, got.SetDeadline(time.Now().Add(-10*time.Hour)))

	resp, err = p.exchangeWithConn(got, req)
	require.Error(t, err)
	require.Nil(t, resp)
}

// testDoTServer is a test DNS-over-TLS server that can be used in unit-tests.
type testDoTServer struct {
	// srv is the *dns.Server instance that listens for DoT requests.
	srv *dns.Server

	// tlsConfig is the TLS configuration that is used for this server.
	tlsConfig *tls.Config

	// rootCAs is the pool with root certificates used by the test server.
	rootCAs *x509.CertPool

	// port is the TCP port that the server listens on.
	port int
}

// type check
var _ io.Closer = (*testDoTServer)(nil)

// startDoTServer starts *testDoTServer on a random port and waits until it
// begins serving.  The server is closed on the test cleanup.
//
// TODO(e.burkov):  Also return address?
func startDoTServer(tb testing.TB, handler dns.HandlerFunc) (s *testDoTServer) {
	tb.Helper()

	tcpListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(tb, err)

	tlsConfig, rootCAs := createServerTLSConfig(tb, "127.0.0.1")
	tlsListener := tls.NewListener(tcpListener, tlsConfig)

	// started is closed by the server once it begins serving.  Waiting on it
	// guarantees the server is running before it's used or shut down.
	started := make(chan struct{})

	srv := &dns.Server{
		Listener:          tlsListener,
		TLSConfig:         tlsConfig,
		Net:               "tls",
		Handler:           handler,
		NotifyStartedFunc: func() { close(started) },
	}

	go func() {
		// If serving fails before the start notification, PanicT panics here
		// instead of leaving the wait below hanging.
		pt := testutil.PanicT{}
		require.NoError(pt, srv.ActivateAndServe())
	}()

	<-started

	s = &testDoTServer{
		srv:       srv,
		tlsConfig: tlsConfig,
		rootCAs:   rootCAs,
		port:      tcpListener.Addr().(*net.TCPAddr).Port,
	}
	testutil.CleanupAndRequireSuccess(tb, s.Close)

	return s
}

// Close implements the io.Closer interface for *testDoTServer.  It shuts the
// underlying DNS server down.
func (s *testDoTServer) Close() error {
	srv := s.srv

	return srv.Shutdown()
}

// BenchmarkDoTUpstream measures parallel exchanges through a DoT upstream
// connected to a local test server.
func BenchmarkDoTUpstream(b *testing.B) {
	srv := startDoTServer(b, func(w dns.ResponseWriter, m *dns.Msg) {
		err := w.WriteMsg(respondToTestMessage(m))
		require.NoError(testutil.PanicT{}, err)
	})

	addr := (&url.URL{
		Scheme: "tls",
		Host:   srv.srv.Listener.Addr().String(),
	}).String()

	u, err := AddressToUpstream(addr, &Options{
		InsecureSkipVerify: true,
	})
	require.NoError(b, err)
	testutil.CleanupAndRequireSuccess(b, u.Close)

	// The producer keeps reqChan filled with requests.  done stops it once
	// the benchmark is over; otherwise it would stay blocked on the channel
	// send forever and leak.
	reqChan := make(chan *dns.Msg, 64)
	done := make(chan struct{})
	b.Cleanup(func() { close(done) })

	go func() {
		for {
			select {
			case reqChan <- createTestMessage():
			case <-done:
				return
			}
		}
	}()

	// Wait for channel to fill.
	require.Eventually(b, func() bool {
		return len(reqChan) == cap(reqChan)
	}, time.Second, time.Millisecond)

	b.Run("exchange_p", func(b *testing.B) {
		b.ResetTimer()
		b.ReportAllocs()

		b.RunParallel(func(p *testing.PB) {
			for p.Next() {
				_, _ = u.Exchange(<-reqChan)
			}
		})
	})
}
0707010000007C000081A4000000000000000000000001650C592100000152000000000000000000000000000000000000002500000000dnsproxy-0.55.0/upstream/dot_unix.go//go:build darwin || freebsd || linux || openbsd || netbsd

package upstream

import (
	"github.com/AdguardTeam/golibs/errors"
	"golang.org/x/sys/unix"
)

// isConnBroken reports whether err means that a connection is broken, i.e.
// that it wraps either EPIPE or ETIMEDOUT.
func isConnBroken(err error) (ok bool) {
	switch {
	case errors.Is(err, unix.EPIPE), errors.Is(err, unix.ETIMEDOUT):
		return true
	default:
		return false
	}
}
0707010000007D000081A4000000000000000000000001650C592100000141000000000000000000000000000000000000002800000000dnsproxy-0.55.0/upstream/dot_windows.go//go:build windows

package upstream

import (
	"github.com/AdguardTeam/golibs/errors"
	"golang.org/x/sys/windows"
)

// isConnBroken reports whether err means that a connection is broken, i.e.
// that it wraps either WSAECONNABORTED or WSAECONNRESET.
func isConnBroken(err error) (ok bool) {
	switch {
	case errors.Is(err, windows.WSAECONNABORTED), errors.Is(err, windows.WSAECONNRESET):
		return true
	default:
		return false
	}
}
0707010000007E000081A4000000000000000000000001650C5921000010A4000000000000000000000000000000000000002500000000dnsproxy-0.55.0/upstream/parallel.gopackage upstream

import (
	"context"
	"net/netip"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
)

// ErrNoUpstreams is returned from the methods that expect at least a single
// upstream to work with when no upstreams are specified.
const ErrNoUpstreams errors.Error = "no upstream specified"

// ExchangeParallel returns the first successful response from one of u along
// with the upstream that produced it.  With several upstreams, req is sent to
// all of them concurrently; a single upstream is queried synchronously.  It
// returns an error if all upstreams failed to exchange the request, or if none
// of them returned a reply.
func ExchangeParallel(u []Upstream, req *dns.Msg) (reply *dns.Msg, resolved Upstream, err error) {
	upsNum := len(u)
	switch upsNum {
	case 0:
		return nil, nil, ErrNoUpstreams
	case 1:
		reply, err = exchangeAndLog(u[0], req)

		return reply, u[0], err
	default:
		// Go on.
	}

	// The channel has room for every upstream's result so that goroutines
	// finishing after an early return below never block.
	ch := make(chan *exchangeResult, upsNum)

	for _, f := range u {
		go exchangeAsync(f, req, ch)
	}

	errs := make([]error, 0, upsNum)
	for range u {
		rep := <-ch
		if rep.err != nil {
			errs = append(errs, rep.err)

			continue
		}

		if rep.reply != nil {
			return rep.reply, rep.upstream, nil
		}
	}

	if len(errs) == 0 {
		// Every upstream returned neither a reply nor an error.
		return nil, nil, errors.Error("none of upstream servers responded")
	}

	// TODO(e.burkov):  Use [errors.Join] in Go 1.20.
	return nil, nil, errors.List("all upstreams failed to respond", errs...)
}

// ExchangeAllResult is the successful result of [ExchangeAll] for a single
// upstream.
type ExchangeAllResult struct {
	// Resp is the response DNS request resolved into.  It's never nil in the
	// results returned by [ExchangeAll].
	Resp *dns.Msg

	// Upstream is the upstream that successfully resolved the request.
	Upstream Upstream
}

// ExchangeAll returns the responses from all of ups that have successfully
// resolved req.  With several upstreams, req is sent to all of them
// concurrently and the results are in the order they arrived.  A nil reply
// without an error counts as a failure.  It returns an error only if all
// upstreams failed to exchange the request.
func ExchangeAll(ups []Upstream, req *dns.Msg) (res []ExchangeAllResult, err error) {
	upsl := len(ups)
	switch upsl {
	case 0:
		return nil, ErrNoUpstreams
	case 1:
		var reply *dns.Msg
		reply, err = exchangeAndLog(ups[0], req)
		if err != nil {
			return nil, err
		} else if reply == nil {
			return nil, errors.Error("no reply")
		}

		return []ExchangeAllResult{{Upstream: ups[0], Resp: reply}}, nil
	default:
		// Go on.
	}

	res = make([]ExchangeAllResult, 0, upsl)
	errs := make([]error, 0, upsl)
	resCh := make(chan *exchangeResult, upsl)

	// Start exchanging concurrently.
	for _, u := range ups {
		go exchangeAsync(u, req, resCh)
	}

	// Wait for all exchanges to finish.
	for range ups {
		rep := <-resCh
		if rep.err != nil {
			errs = append(errs, rep.err)

			continue
		}

		if rep.reply == nil {
			errs = append(errs, errors.Error("no reply"))

			continue
		}

		res = append(res, ExchangeAllResult{
			Resp:     rep.reply,
			Upstream: rep.upstream,
		})
	}

	if len(errs) == upsl {
		// TODO(e.burkov):  Use [errors.Join] in Go 1.20.
		return res, errors.List("all upstreams failed to exchange", errs...)
	}

	return res, nil
}

// exchangeResult represents the result of DNS exchange.
type exchangeResult = struct {
	// upstream is the Upstream the request was exchanged with.  It's set
	// whether or not the exchange succeeded.
	upstream Upstream

	// reply is the response DNS request resolved into.  It may be nil.
	reply *dns.Msg

	// err is the error that occurred while resolving the request.
	err error
}

// exchangeAsync resolves req using u and sends the outcome, tagged with u, to
// respCh.  Callers run it in a separate goroutine per upstream.
func exchangeAsync(u Upstream, req *dns.Msg, respCh chan *exchangeResult) {
	reply, err := exchangeAndLog(u, req)

	respCh <- &exchangeResult{
		upstream: u,
		reply:    reply,
		err:      err,
	}
}

// exchangeAndLog wraps the [Upstream.Exchange] method with logging.
func exchangeAndLog(u Upstream, req *dns.Msg) (resp *dns.Msg, err error) {
	addr := u.Address()
	req = req.Copy()

	start := time.Now()
	reply, err := u.Exchange(req)
	elapsed := time.Since(start)

	if q := &req.Question[0]; err == nil {
		log.Debug("upstream %s exchanged %s successfully in %s", addr, q, elapsed)
	} else {
		log.Debug("upstream %s failed to exchange %s in %s: %s", addr, q, elapsed, err)
	}

	return reply, err
}

// LookupParallel looks up the IP addresses of host using all resolvers
// concurrently.  It delegates to [bootstrap.LookupParallel].
func LookupParallel(ctx context.Context, resolvers []Resolver, host string) ([]netip.Addr, error) {
	addrs, err := bootstrap.LookupParallel(ctx, resolvers, host)

	return addrs, err
}
0707010000007F000081A4000000000000000000000001650C592100000F17000000000000000000000000000000000000002A00000000dnsproxy-0.55.0/upstream/parallel_test.gopackage upstream

import (
	"context"
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
)

const (
	// timeout is the upstream and lookup timeout used by the tests.
	timeout = 5 * time.Second
)

// TestExchangeParallel launches several parallel exchanges and checks that the
// answer comes from 8.8.8.8:53 within the configured timeout.
func TestExchangeParallel(t *testing.T) {
	addrs := []string{"1.2.3.4:55", "8.8.8.1", "8.8.8.8:53"}
	ups := make([]Upstream, 0, len(addrs))

	for _, a := range addrs {
		up, err := AddressToUpstream(a, &Options{Timeout: timeout})
		if err != nil {
			t.Fatalf("cannot create upstream: %s", err)
		}

		ups = append(ups, up)
	}

	req := createTestMessage()
	start := time.Now()
	resp, u, err := ExchangeParallel(ups, req)
	if err != nil {
		t.Fatalf("no response from test upstreams: %s", err)
	}

	if addr := u.Address(); addr != "8.8.8.8:53" {
		t.Fatalf("shouldn't happen. This upstream can't resolve DNS request: %s", addr)
	}

	requireResponse(t, req, resp)

	if elapsed := time.Since(start); elapsed > timeout {
		t.Fatalf("exchange took more time than the configured timeout: %v", elapsed)
	}
}

// TestLookupParallel checks that LookupParallel resolves a host using a mix of
// resolvers within the configured timeout.
func TestLookupParallel(t *testing.T) {
	bootstraps := []string{"1.2.3.4:55", "8.8.8.1:555", "8.8.8.8:53"}
	resolvers := make([]Resolver, 0, len(bootstraps))

	for _, boot := range bootstraps {
		// Don't ignore the error: a resolver returned alongside it is likely
		// unusable and would break the lookup in a less obvious way.
		resolver, err := NewUpstreamResolver(boot, &Options{Timeout: timeout})
		if err != nil {
			t.Fatalf("cannot create resolver for %s: %s", boot, err)
		}

		resolvers = append(resolvers, resolver)
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	start := time.Now()
	answer, err := LookupParallel(ctx, resolvers, "google.com")
	if err != nil || answer == nil {
		t.Fatalf("failed to lookup %s", err)
	}

	elapsed := time.Since(start)
	if elapsed > timeout {
		t.Fatalf("lookup took more time than the configured timeout: %v", elapsed)
	}
}

// TestLookupParallelEmpty checks that LookupParallel returns no addresses and
// no error when the resolvers respond without any answers.
func TestLookupParallelEmpty(t *testing.T) {
	resolvers := []Resolver{
		&upstreamResolver{Upstream: &testUpstream{}},
		&upstreamResolver{Upstream: &testUpstream{}},
	}

	ctx, cancel := context.WithTimeout(context.TODO(), timeout)
	defer cancel()

	addrs, err := LookupParallel(ctx, resolvers, "google.com")
	assert.Nil(t, err)
	assert.Len(t, addrs, 0)
}

// TestExchangeParallelEmpty checks that ExchangeParallel fails when every
// upstream returns neither a response nor an error.
func TestExchangeParallelEmpty(t *testing.T) {
	ups := []Upstream{
		&testUpstream{empty: true},
		&testUpstream{empty: true},
	}

	resp, up, err := ExchangeParallel(ups, createTestMessage())
	assert.NotNil(t, err)
	assert.Nil(t, resp)
	assert.Nil(t, up)
}

// testUpstream is a configurable fake Upstream for tests.
type testUpstream struct {
	a     net.IP        // an address for the A record in the response, if set
	err   bool          // whether to return an error instead of a response
	empty bool          // whether to return neither a response nor an error
	sleep time.Duration // a delay before response
}

// type check
var _ Upstream = (*testUpstream)(nil)

// Exchange implements the Upstream interface for *testUpstream.  After waiting
// for u.sleep, it returns nothing if u.empty is set, an error if u.err is set,
// and otherwise a reply to req with u.a as the only A record, if u.a is set.
func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	if u.sleep != 0 {
		time.Sleep(u.sleep)
	}

	switch {
	case u.empty:
		return nil, nil
	case u.err:
		return nil, fmt.Errorf("upstream error")
	}

	resp = (&dns.Msg{}).SetReply(req)
	if len(u.a) > 0 {
		resp.Answer = append(resp.Answer, &dns.A{A: u.a})
	}

	return resp, nil
}

// Address implements the Upstream interface for *testUpstream.  The fake
// upstream has no address, so the result is always empty.
func (u *testUpstream) Address() (addr string) {
	return addr
}

// Close implements the Upstream interface for *testUpstream.  There is
// nothing to release, so it always succeeds.
func (u *testUpstream) Close() (err error) {
	return err
}

// TestExchangeAll checks that ExchangeAll collects the responses of all
// successful upstreams in the order they arrive and skips the failed one.
func TestExchangeAll(t *testing.T) {
	slowUps := &testUpstream{
		a:     net.ParseIP("1.1.1.1"),
		sleep: 100 * time.Millisecond,
	}
	errUps := &testUpstream{err: true}
	fastUps := &testUpstream{a: net.ParseIP("3.3.3.3")}

	ups := []Upstream{slowUps, errUps, fastUps}
	req := createHostTestMessage("test.org")
	res, err := ExchangeAll(ups, req)
	assert.NoError(t, err)

	// Don't index res below when its length is unexpected, since that would
	// panic instead of failing the test.
	if !assert.Len(t, res, 2) {
		return
	}

	// The delayed upstream is expected to respond last.
	wantIPs := []net.IP{net.ParseIP("3.3.3.3"), net.ParseIP("1.1.1.1")}
	for i, want := range wantIPs {
		ans := res[i].Resp.Answer
		if !assert.Len(t, ans, 1) {
			continue
		}

		a, ok := ans[0].(*dns.A)
		if assert.True(t, ok) {
			assert.True(t, a.A.To4().Equal(want.To4()))
		}
	}
}
07070100000080000081A4000000000000000000000001650C592100001549000000000000000000000000000000000000002200000000dnsproxy-0.55.0/upstream/plain.gopackage upstream

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/url"
	"strings"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
)

// network is the semantic type alias of the network to pass to dialing
// functions.  It's either [networkUDP] or [networkTCP].  It may also be used as
// the URL scheme for plain upstreams.
type network = string

const (
	// networkUDP is the UDP network.
	networkUDP network = "udp"

	// networkTCP is the TCP network.
	networkTCP network = "tcp"
)

// plainDNS implements the [Upstream] interface for the regular DNS protocol.
type plainDNS struct {
	// addr is the DNS server URL.  Scheme is always "udp" or "tcp".
	addr *url.URL

	// getDialer either returns an initialized dial handler or creates a new
	// one.
	getDialer DialerInitializer

	// net is the network of the connections.  newPlain sets it to
	// addr.Scheme.
	net network

	// timeout is the timeout for DNS requests.  It's used as the timeout of
	// the DNS client on each exchange.
	timeout time.Duration
}

// type check
var _ Upstream = &plainDNS{}

// newPlain returns the plain DNS Upstream for addr, after adding the default
// plain DNS port to it via addPort.  addr.Scheme must be either "udp" or
// "tcp"; any other scheme results in an error.
func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) {
	if s := addr.Scheme; s != networkUDP && s != networkTCP {
		return nil, fmt.Errorf("unsupported url scheme: %s", s)
	}

	addPort(addr, defaultPortPlain)

	getDialer, err := newDialerInitializer(addr, opts)
	if err != nil {
		return nil, err
	}

	u = &plainDNS{
		addr:      addr,
		getDialer: getDialer,
		net:       addr.Scheme,
		timeout:   opts.Timeout,
	}

	return u, nil
}

// Address implements the [Upstream] interface for *plainDNS.  For UDP it
// returns only the host part of the URL, while for TCP it returns the whole
// URL including the scheme.  It panics on any other network.
func (p *plainDNS) Address() string {
	if p.net == networkUDP {
		return p.addr.Host
	}

	if p.net == networkTCP {
		return p.addr.String()
	}

	panic(fmt.Sprintf("unexpected network: %s", p.net))
}

// dialExchange performs a DNS exchange with the specified dial handler.
// network must be either [networkUDP] or [networkTCP].  If the first exchange
// fails with an error accepted by isExpectedConnErr, a new connection is
// dialed and the exchange is retried once.  A successful response is then
// checked with validatePlainResponse.
func (p *plainDNS) dialExchange(
	network network,
	dial bootstrap.DialHandler,
	req *dns.Msg,
) (resp *dns.Msg, err error) {
	addr := p.Address()
	client := &dns.Client{Timeout: p.timeout}

	conn := &dns.Conn{}
	if network == networkUDP {
		// NOTE(review): presumably sizes the UDP read buffer for responses;
		// confirm against the miekg/dns Conn.UDPSize documentation.
		conn.UDPSize = dns.MinMsgSize
	}

	logBegin(addr, network, req)
	// The closure reads err when the function returns, so it sees the final
	// error, including the ones from closing connections, whose deferred
	// calls below run first.
	defer func() { logFinish(addr, network, err) }()

	ctx := context.Background()
	conn.Conn, err = dial(ctx, network, "")
	if err != nil {
		return nil, fmt.Errorf("dialing %s over %s: %w", p.addr.Host, network, err)
	}
	// Pass the current connection explicitly, since conn.Conn may be
	// replaced by the retry below and both connections must be closed.
	defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn)

	resp, _, err = client.ExchangeWithConn(req, conn)
	if isExpectedConnErr(err) {
		// Retry once over a freshly dialed connection.
		conn.Conn, err = dial(ctx, network, "")
		if err != nil {
			return nil, fmt.Errorf("dialing %s over %s again: %w", p.addr.Host, network, err)
		}
		defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn)

		resp, _, err = client.ExchangeWithConn(req, conn)
	}

	if err != nil {
		return resp, fmt.Errorf("exchanging with %s over %s: %w", addr, network, err)
	}

	return resp, validatePlainResponse(req, resp)
}

// isExpectedConnErr reports whether err is a connection error worth a second
// attempt: either any [net.Error] or [io.EOF] anywhere in the chain.  A nil
// err is never expected.
func isExpectedConnErr(err error) (is bool) {
	if err == nil {
		return false
	}

	if errors.Is(err, io.EOF) {
		return true
	}

	var netErr net.Error

	return errors.As(err, &netErr)
}

// Exchange implements the [Upstream] interface for *plainDNS.  For UDP
// upstreams it falls back to TCP when the response is truncated or its
// question section doesn't match the request.
func (p *plainDNS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
	dial, err := p.getDialer()
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return nil, err
	}

	addr := p.Address()

	resp, err = p.dialExchange(p.net, dial, req)
	if p.net != networkUDP {
		// The network is already TCP.
		return resp, err
	}

	if resp == nil {
		// There is likely an error with the upstream.
		return resp, err
	}

	if errors.Is(err, errQuestion) {
		// The upstream responds with malformed messages, so try TCP.
		log.Debug("plain %s: %s, using tcp", addr, err)

		return p.dialExchange(networkTCP, dial, req)
	} else if resp.Truncated {
		// Fallback to TCP on truncated responses.  The arguments follow the
		// format: the upstream address first, then the question.
		log.Debug("plain %s: resp for %s is truncated, using tcp", addr, &req.Question[0])

		return p.dialExchange(networkTCP, dial, req)
	}

	// There is either no error or the error isn't related to the received
	// message.
	return resp, err
}

// Close implements the [Upstream] interface for *plainDNS.  Connections are
// dialed and closed within each exchange, so there is nothing to release.
func (p *plainDNS) Close() (err error) {
	// Nothing to do.
	return nil
}

// errQuestion is returned when a message has malformed question section, e.g.
// one that doesn't match the request.  For UDP upstreams, Exchange retries
// over TCP on this error.
const errQuestion errors.Error = "bad question section"

// validatePlainResponse validates resp from an upstream DNS server for
// compliance with req.  Any error returned wraps [errQuestion], since it
// essentially validates the question section of resp.
func validatePlainResponse(req, resp *dns.Msg) (err error) {
	if qlen := len(resp.Question); qlen != 1 {
		return fmt.Errorf("%w: only 1 question allowed; got %d", errQuestion, qlen)
	}

	// Make sure there's something to compare with, since indexing a request
	// without questions below would panic.
	if len(req.Question) == 0 {
		return fmt.Errorf("%w: request has no question", errQuestion)
	}

	reqQ, respQ := req.Question[0], resp.Question[0]

	if reqQ.Qtype != respQ.Qtype {
		return fmt.Errorf("%w: mismatched type %s", errQuestion, dns.Type(respQ.Qtype))
	}

	// Compare the names case-insensitively, just like CoreDNS does.
	if !strings.EqualFold(reqQ.Name, respQ.Name) {
		return fmt.Errorf("%w: mismatched name %q", errQuestion, respQ.Name)
	}

	return nil
}
07070100000081000081A4000000000000000000000001650C5921000011AB000000000000000000000000000000000000002700000000dnsproxy-0.55.0/upstream/plain_test.gopackage upstream

import (
	"fmt"
	"io"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestUpstream_plainDNS(t *testing.T) {
	srv := startDNSServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		resp := respondToTestMessage(req)

		err := w.WriteMsg(resp)

		pt := testutil.PanicT{}
		require.NoError(pt, err)
	})
	testutil.CleanupAndRequireSuccess(t, srv.Close)

	addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
	u, err := AddressToUpstream(addr, &Options{})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	for i := 0; i < 10; i++ {
		checkUpstream(t, u, addr)
	}
}

// TestUpstream_plainDNS_badID checks that a response with a mismatched ID
// isn't accepted, so the exchange ends with a timeout error.
func TestUpstream_plainDNS_badID(t *testing.T) {
	req := createTestMessage()

	badIDResp := respondToTestMessage(req)
	badIDResp.Id++

	handler := func(w dns.ResponseWriter, _ *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(badIDResp))
	}
	srv := startDNSServer(t, handler)
	testutil.CleanupAndRequireSuccess(t, srv.Close)

	addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
	opts := &Options{
		// Use a shorter timeout to speed up the test.
		Timeout: 100 * time.Millisecond,
	}
	u, err := AddressToUpstream(addr, opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	resp, err := u.Exchange(req)

	var netErr net.Error
	require.ErrorAs(t, err, &netErr)

	assert.True(t, netErr.Timeout())
	assert.Nil(t, resp)
}

// TestUpstream_plainDNS_fallbackToTCP checks that a UDP upstream retries the
// request over TCP when the UDP response is truncated or its question section
// doesn't match the request, and doesn't retry for a good response.
func TestUpstream_plainDNS_fallbackToTCP(t *testing.T) {
	req := createTestMessage()
	goodResp := respondToTestMessage(req)

	truncResp := goodResp.Copy()
	truncResp.Truncated = true

	badQNameResp := goodResp.Copy()
	badQNameResp.Question[0].Name = "bad." + req.Question[0].Name

	badQTypeResp := goodResp.Copy()
	badQTypeResp.Question[0].Qtype = dns.TypeCNAME

	// udpResp is what the server sends over UDP; wantUDP and wantTCP are the
	// expected numbers of requests received over each network.
	testCases := []struct {
		udpResp *dns.Msg
		name    string
		wantUDP int
		wantTCP int
	}{{
		udpResp: goodResp,
		name:    "all_right",
		wantUDP: 1,
		wantTCP: 0,
	}, {
		udpResp: truncResp,
		name:    "truncated_response",
		wantUDP: 1,
		wantTCP: 1,
	}, {
		udpResp: badQNameResp,
		name:    "bad_qname",
		wantUDP: 1,
		wantTCP: 1,
	}, {
		udpResp: badQTypeResp,
		name:    "bad_qtype",
		wantUDP: 1,
		wantTCP: 1,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// The counters are atomic since the handler runs in the server's
			// goroutines.
			var udpReqNum, tcpReqNum atomic.Uint32
			srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) {
				// Respond with the test case's message over UDP and always
				// with the good response over TCP.
				var resp *dns.Msg
				if w.RemoteAddr().Network() == networkUDP {
					udpReqNum.Add(1)
					resp = tc.udpResp
				} else {
					tcpReqNum.Add(1)
					resp = goodResp
				}

				require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
			})
			testutil.CleanupAndRequireSuccess(t, srv.Close)

			addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
			u, err := AddressToUpstream(addr, &Options{
				// Use a shorter timeout to speed up the test.
				Timeout: 100 * time.Millisecond,
			})
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			resp, err := u.Exchange(req)
			require.NoError(t, err)
			requireResponse(t, req, resp)

			assert.Equal(t, tc.wantUDP, int(udpReqNum.Load()))
			assert.Equal(t, tc.wantTCP, int(tcpReqNum.Load()))
		})
	}
}

// testDNSServer is a simple DNS server that can be used in unit-tests.  It
// serves both UDP and TCP on the same port.
type testDNSServer struct {
	udpListener net.PacketConn
	tcpListener net.Listener
	udpSrv      *dns.Server
	tcpSrv      *dns.Server
	port        int
}

// type check
var _ io.Closer = (*testDNSServer)(nil)

// startDNSServer starts a test DNS server serving handler over both UDP and
// TCP on the same random port of 127.0.0.1.  The caller is responsible for
// closing it.
func startDNSServer(t *testing.T, handler dns.HandlerFunc) (s *testDNSServer) {
	t.Helper()

	s = &testDNSServer{}

	udpListener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)

	s.port = udpListener.LocalAddr().(*net.UDPAddr).Port
	s.udpListener = udpListener

	s.tcpListener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", s.port))
	if err != nil {
		// Don't leak the UDP socket: the server that would close it on
		// shutdown hasn't been started yet.
		_ = udpListener.Close()
	}
	require.NoError(t, err)

	s.udpSrv = &dns.Server{
		PacketConn: s.udpListener,
		Handler:    handler,
	}

	s.tcpSrv = &dns.Server{
		Listener: s.tcpListener,
		Handler:  handler,
	}

	go func() {
		pt := testutil.PanicT{}
		require.NoError(pt, s.udpSrv.ActivateAndServe())
	}()

	go func() {
		pt := testutil.PanicT{}
		require.NoError(pt, s.tcpSrv.ActivateAndServe())
	}()

	return s
}

// Close implements the io.Closer interface for *testDNSServer.  It shuts down
// the UDP server and then the TCP one, combining their errors with
// errors.WithDeferred.
func (s *testDNSServer) Close() (err error) {
	err = s.udpSrv.Shutdown()

	return errors.WithDeferred(err, s.tcpSrv.Shutdown())
}
07070100000082000081A4000000000000000000000001650C592100003CF1000000000000000000000000000000000000002100000000dnsproxy-0.55.0/upstream/quic.gopackage upstream

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"net/url"
	"runtime"
	"sync"
	"time"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
)

const (
	// QUICCodeNoError is used when the connection or stream needs to be closed,
	// but there is no error to signal.
	QUICCodeNoError = quic.ApplicationErrorCode(0)

	// QUICCodeInternalError signals that the DoQ implementation encountered
	// an internal error and is incapable of pursuing the transaction or the
	// connection.
	QUICCodeInternalError = quic.ApplicationErrorCode(1)

	// QUICKeepAlivePeriod is the value that we pass to *quic.Config and that
	// controls the period with which keep-alive frames are being sent to the
	// connection.  We set it to 20s as it would be in the quic-go@v0.27.1 with
	// the KeepAlive field set to true.  This value is specified in
	// https://pkg.go.dev/github.com/quic-go/quic-go/internal/protocol#MaxKeepAliveInterval.
	//
	// TODO(ameshkov):  Consider making it configurable.
	QUICKeepAlivePeriod = time.Second * 20

	// NextProtoDQ is the ALPN token for DoQ.  During the connection
	// establishment, DNS/QUIC support is indicated by selecting the ALPN token
	// "doq" in the crypto handshake.
	//
	// See https://datatracker.ietf.org/doc/rfc9250.
	NextProtoDQ = "doq"
)

// compatProtoDQ is a list of ALPN tokens offered by a QUIC connection.
// NextProtoDQ, the token from RFC 9250, comes first, and the rest are the
// tokens of previous drafts kept for compatibility.
var compatProtoDQ = []string{NextProtoDQ, "doq-i00", "dq", "doq-i02"}

// dnsOverQUIC implements the [Upstream] interface for the DNS-over-QUIC
// protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html).
type dnsOverQUIC struct {
	// getDialer either returns an initialized dial handler or creates a new
	// one.
	getDialer DialerInitializer

	// addr is the DNS-over-QUIC server URL.
	addr *url.URL

	// tlsConf is the configuration of TLS.  It's cloned for each new
	// connection.
	tlsConf *tls.Config

	// quicConfig is the QUIC configuration that is used for establishing
	// connections to the upstream.  This configuration includes the TokenStore
	// that needs to be stored for the lifetime of dnsOverQUIC since we can
	// re-create the connection.
	quicConfig *quic.Config

	// conn is the current active QUIC connection.  It can be closed and
	// re-opened when needed.
	conn quic.Connection

	// bytesPool is a *sync.Pool we use to store byte buffers in.  These byte
	// buffers are used to read responses from the upstream.  It's created
	// lazily by getBytesPool.
	bytesPool *sync.Pool

	// quicConfigMu protects quicConfig.
	quicConfigMu sync.Mutex

	// connMu protects conn.
	connMu sync.RWMutex

	// bytesPoolMu protects bytesPool.
	bytesPoolMu sync.Mutex

	// timeout is the timeout for the upstream connection.
	timeout time.Duration
}

// type check
var _ Upstream = (*dnsOverQUIC)(nil)

// newDoQ returns the DNS-over-QUIC Upstream for addr, after adding the default
// DoQ port to it via addPort.  A finalizer calling Close is set on the
// returned upstream; Close removes it.
func newDoQ(addr *url.URL, opts *Options) (u Upstream, err error) {
	addPort(addr, defaultPortDoQ)

	getDialer, err := newDialerInitializer(addr, opts)
	if err != nil {
		return nil, err
	}

	quicConf := &quic.Config{
		KeepAlivePeriod: QUICKeepAlivePeriod,
		TokenStore:      newQUICTokenStore(),
		Tracer:          opts.QUICTracer,
	}

	// #nosec G402 -- TLS certificate verification could be disabled by
	// configuration.
	tlsConf := &tls.Config{
		ServerName:   addr.Hostname(),
		RootCAs:      opts.RootCAs,
		CipherSuites: opts.CipherSuites,
		// Use the default capacity for the LRU cache.  It may be useful to
		// store several caches since the user may be routed to different
		// servers in case there's load balancing on the server-side.
		ClientSessionCache:    tls.NewLRUClientSessionCache(0),
		MinVersion:            tls.VersionTLS12,
		InsecureSkipVerify:    opts.InsecureSkipVerify,
		VerifyPeerCertificate: opts.VerifyServerCertificate,
		VerifyConnection:      opts.VerifyConnection,
		NextProtos:            compatProtoDQ,
	}

	p := &dnsOverQUIC{
		getDialer:  getDialer,
		addr:       addr,
		quicConfig: quicConf,
		tlsConf:    tlsConf,
		timeout:    opts.Timeout,
	}

	runtime.SetFinalizer(p, (*dnsOverQUIC).Close)

	return p, nil
}

// Address implements the [Upstream] interface for *dnsOverQUIC.  It returns
// the whole server URL.
func (p *dnsOverQUIC) Address() string {
	u := p.addr

	return u.String()
}

// Exchange implements the [Upstream] interface for *dnsOverQUIC.  It sets m.Id
// to zero for the duration of the exchange and restores it, also copying it
// into resp, before returning.
//
// NOTE(review): since m is modified in place, sharing the same m between
// concurrent exchanges is presumably unsafe; confirm with callers.
func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	// When sending queries over a QUIC connection, the DNS Message ID MUST be
	// set to zero.
	id := m.Id
	m.Id = 0
	defer func() {
		// Restore the original ID to not break compatibility with proxies.
		m.Id = id
		if resp != nil {
			resp.Id = id
		}
	}()

	// Check if there was already an active conn before sending the request.
	// We'll only attempt to re-connect if there was one.
	hasConnection := p.hasConnection()

	// Make the first attempt to send the DNS query.
	resp, err = p.exchangeQUIC(m)

	// Make up to 2 attempts to re-open the QUIC connection and send the request
	// again.  There are several cases where this workaround is necessary to
	// make DoQ usable.  We need to make 2 attempts in the case when the
	// connection was closed (due to inactivity for example) AND the server
	// refuses to open a 0-RTT connection.
	for i := 0; hasConnection && p.shouldRetry(err) && i < 2; i++ {
		log.Debug("re-creating the QUIC connection and retrying due to %v", err)

		// Close the active connection to make sure we'll try to re-connect.
		p.closeConnWithError(err)

		// Retry sending the request.
		resp, err = p.exchangeQUIC(m)
	}

	if err != nil {
		// If we're unable to exchange messages, make sure the connection is
		// closed and signal about an internal error.
		p.closeConnWithError(err)
	}

	return resp, err
}

// Close implements the [Upstream] interface for *dnsOverQUIC.  It removes the
// finalizer set by newDoQ and closes the active connection, if any, without
// signaling an error to the server.
func (p *dnsOverQUIC) Close() (err error) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	runtime.SetFinalizer(p, nil)

	if p.conn == nil {
		return nil
	}

	return p.conn.CloseWithError(QUICCodeNoError, "")
}

// exchangeQUIC attempts to open a QUIC connection, send the DNS message
// through it and return the response it got from the server.
func (p *dnsOverQUIC) exchangeQUIC(m *dns.Msg) (resp *dns.Msg, err error) {
	var conn quic.Connection
	conn, err = p.getConnection(true)
	if err != nil {
		return nil, err
	}

	var buf []byte
	buf, err = m.Pack()
	if err != nil {
		return nil, fmt.Errorf("failed to pack DNS message for DoQ: %w", err)
	}

	var stream quic.Stream
	stream, err = p.openStream(conn)
	if err != nil {
		return nil, err
	}

	_, err = stream.Write(proxyutil.AddPrefix(buf))
	if err != nil {
		return nil, fmt.Errorf("failed to write to a QUIC stream: %w", err)
	}

	// The client MUST send the DNS query over the selected stream, and MUST
	// indicate through the STREAM FIN mechanism that no further data will
	// be sent on that stream. Note, that stream.Close() closes the
	// write-direction of the stream, but does not prevent reading from it.
	_ = stream.Close()

	return p.readMsg(stream)
}

// shouldRetry reports whether err requires re-opening the connection and
// retrying the request.  The decision is made by isQUICRetryError.
func (p *dnsOverQUIC) shouldRetry(err error) (ok bool) {
	ok = isQUICRetryError(err)

	return ok
}

// getBytesPool returns the pool of byte buffers, creating it on the first
// call.  Each buffer is a pointer to a slice of dns.MaxMsgSize bytes.
func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
	p.bytesPoolMu.Lock()
	defer p.bytesPoolMu.Unlock()

	if p.bytesPool != nil {
		return p.bytesPool
	}

	p.bytesPool = &sync.Pool{
		New: func() any {
			buf := make([]byte, dns.MaxMsgSize)

			return &buf
		},
	}

	return p.bytesPool
}

// getConnection opens or returns an existing quic.Connection.  useCached
// argument controls whether we should try to use the existing cached
// connection.  If it is false, we will forcibly create a new connection and
// close the existing one if needed.
//
// If another goroutine has replaced the connection observed by this call in
// the meantime, that newer connection is returned instead of opening yet
// another one, so concurrent callers don't leak connections.
func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) {
	p.connMu.RLock()
	conn := p.conn
	p.connMu.RUnlock()

	if conn != nil && useCached {
		return conn, nil
	}

	p.connMu.Lock()
	defer p.connMu.Unlock()

	if p.conn != nil && p.conn != conn {
		// Another goroutine has opened a fresh connection while the lock was
		// released, so reuse it.
		return p.conn, nil
	}

	if p.conn != nil {
		// We're recreating the connection, so close the current one.  This is
		// done under the write lock so that no one else closes it
		// concurrently.
		_ = p.conn.CloseWithError(QUICCodeNoError, "")
		p.conn = nil
	}

	newConn, err := p.openConnection()
	if err != nil {
		return nil, err
	}
	p.conn = newConn

	return newConn, nil
}

// hasConnection returns true if there's an active QUIC connection.  It only
// reads p.conn, so the shared lock is enough.
func (p *dnsOverQUIC) hasConnection() (ok bool) {
	p.connMu.RLock()
	defer p.connMu.RUnlock()

	return p.conn != nil
}

// getQUICConfig returns the current QUIC config under quicConfigMu.  The
// returned config is shared, so callers must not modify it.
func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) {
	p.quicConfigMu.Lock()
	c = p.quicConfig
	p.quicConfigMu.Unlock()

	return c
}

// resetQUICConfig replaces the QUIC config with a clone that has a fresh token
// store, as we may need to use a new one if we failed to connect.  The old
// config isn't modified in place, since getQUICConfig may have handed it out.
func (p *dnsOverQUIC) resetQUICConfig() {
	p.quicConfigMu.Lock()
	defer p.quicConfigMu.Unlock()

	conf := p.quicConfig.Clone()
	conf.TokenStore = newQUICTokenStore()
	p.quicConfig = conf
}

// openStream opens a new QUIC stream for the specified connection.  If that
// fails, the connection is re-created and a stream is opened on the new one.
// Each attempt gets its own deadline from p.withDeadline.
func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) {
	ctx, cancel := p.withDeadline(context.Background())
	defer cancel()

	stream, err := conn.OpenStreamSync(ctx)
	if err == nil {
		return stream, nil
	}

	// We can get here if the old QUIC connection is not valid anymore.  We
	// should try to re-create the connection again in this case.
	newConn, err := p.getConnection(false)
	if err != nil {
		return nil, err
	}

	// Use a fresh deadline, since the failed attempt above may have used up
	// the whole timeout, which would make this attempt fail immediately.
	retryCtx, retryCancel := p.withDeadline(context.Background())
	defer retryCancel()

	// Open a new stream.
	return newConn.OpenStreamSync(retryCtx)
}

// openConnection opens a new QUIC connection.  The bootstrapped dialer is only
// used to pick the remote address; the QUIC connection itself is established
// with quic.DialAddrEarly using a clone of p.tlsConf and the current QUIC
// config.
func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
	dialContext, err := p.getDialer()
	if err != nil {
		return nil, fmt.Errorf("failed to bootstrap QUIC connection: %w", err)
	}

	// Use the bootstrapped address instead of the one from p.addr.  Dialing
	// UDP doesn't create an actual connection, but it helps us determine which
	// IP is actually reachable when there are both IPv4 and IPv6 addresses.
	rawConn, err := dialContext(context.Background(), "udp", "")
	if err != nil {
		return nil, fmt.Errorf("failed to open a QUIC connection: %w", err)
	}
	// The socket itself is never used, only its remote address is.
	_ = rawConn.Close()

	udpConn, ok := rawConn.(*net.UDPConn)
	if !ok {
		return nil, fmt.Errorf("failed to open connection to %s", p.addr)
	}

	addr := udpConn.RemoteAddr().String()

	ctx, cancel := p.withDeadline(context.Background())
	defer cancel()

	conn, err = quic.DialAddrEarly(ctx, addr, p.tlsConf.Clone(), p.getQUICConfig())
	if err != nil {
		return nil, fmt.Errorf("opening quic connection to %s: %w", p.addr, err)
	}

	return conn, nil
}

// closeConnWithError closes the active connection, if any, so that new queries
// are processed over another connection.  This is meant for fatal errors.  A
// nil err closes with QUICCodeNoError; any other error uses
// QUICCodeInternalError.  If err is quic.Err0RTTRejected, the token store is
// reset as well.
func (p *dnsOverQUIC) closeConnWithError(err error) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	active := p.conn
	if active == nil {
		// There's no active connection, nothing to close.
		return
	}

	code := QUICCodeInternalError
	if err == nil {
		code = QUICCodeNoError
	}

	// The tokens are only invalid when 0-RTT has been rejected.
	if errors.Is(err, quic.Err0RTTRejected) {
		p.resetQUICConfig()
	}

	if closeErr := active.CloseWithError(code, ""); closeErr != nil {
		log.Error("failed to close the conn: %v", closeErr)
	}
	p.conn = nil
}

// readMsg reads the incoming DNS message from the QUIC stream.
//
// All DNS messages (queries and responses) sent over DoQ connections MUST be
// encoded as a 2-octet length field followed by the message content as
// specified in [RFC1035].  The length field determines how many bytes belong
// to the message.  The stream is read until that many bytes have arrived,
// since a single Read call on a QUIC stream may return only a part of the
// data.  Only those bytes are unpacked, so stale data left in the pooled
// buffer is never parsed.
//
// IMPORTANT: Note, that this implementation does not support receiving
// multiple messages over a single stream; anything after the first message is
// ignored.
func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *dns.Msg, err error) {
	pool := p.getBytesPool()
	bufPtr := pool.Get().(*[]byte)

	defer pool.Put(bufPtr)

	respBuf := *bufPtr

	// readFull fills buf completely from stream.  An error that arrives
	// together with the last missing bytes (e.g. io.EOF) is not a failure.
	readFull := func(buf []byte) (rErr error) {
		for read := 0; read < len(buf); {
			var n int
			n, rErr = stream.Read(buf[read:])
			read += n
			if rErr != nil && read < len(buf) {
				return rErr
			}
		}

		return nil
	}

	err = readFull(respBuf[:2])
	if err != nil {
		return nil, fmt.Errorf("reading response from %s: %w", p.addr, err)
	}

	msgLen := int(respBuf[0])<<8 | int(respBuf[1])
	if maxLen := len(respBuf) - 2; msgLen > maxLen {
		return nil, fmt.Errorf(
			"reading response from %s: message length %d exceeds buffer size %d",
			p.addr,
			msgLen,
			maxLen,
		)
	}

	msgBuf := respBuf[2 : 2+msgLen]
	err = readFull(msgBuf)
	if err != nil {
		return nil, fmt.Errorf("reading response from %s: %w", p.addr, err)
	}

	m = new(dns.Msg)
	err = m.Unpack(msgBuf)
	if err != nil {
		return nil, fmt.Errorf("unpacking response from %s: %w", p.addr, err)
	}

	return m, nil
}

// newQUICTokenStore returns a fresh LRU quic.TokenStore.  A token store is
// required to benefit from 0-RTT.  See the address validation section of
// RFC 9000: https://datatracker.ietf.org/doc/html/rfc9000#section-8.1
func newQUICTokenStore() (s quic.TokenStore) {
	// A single origin holding up to 10 tokens is assumed to be more than
	// enough, since only one connection per upstream is used.
	const (
		maxOrigins      = 1
		tokensPerOrigin = 10
	)

	return quic.NewLRUTokenStore(maxOrigins, tokensPerOrigin)
}

// isQUICRetryError reports whether err may signal that the QUIC connection
// should be re-created.  The need for this is caused by quic-go issues; see
// the comments on each case below.
// TODO(ameshkov): re-test when updating quic-go.
func isQUICRetryError(err error) (ok bool) {
	var (
		appErr       *quic.ApplicationError
		idleErr      *quic.IdleTimeoutError
		resetErr     *quic.StatelessResetError
		transportErr *quic.TransportError
	)

	switch {
	case errors.As(err, &appErr) &&
		(appErr.ErrorCode == 0 ||
			appErr.ErrorCode == quic.ApplicationErrorCode(http3.ErrCodeNoError)):
		// Code 0 is often returned when the server has been restarted and the
		// client keeps using the same connection.  An HTTP/3 server may close
		// an idle connection with http3.ErrCodeNoError, and the HTTP client
		// doesn't close such connections immediately, so this needs handling.
		return true
	case errors.As(err, &idleErr):
		// The connection was closed for being idle, so it must be re-created.
		// To reproduce, stop the server, wait for 30 seconds, and send another
		// request via the same upstream.
		return true
	case errors.As(err, &resetErr):
		// A server sends a stateless reset when it can't decrypt a received
		// QUIC packet, e.g. after a recent reboot.  Reconnect and try again.
		return true
	case errors.As(err, &transportErr) && transportErr.ErrorCode == quic.NoError:
		// A server may close a connection it considers old with a NO_ERROR
		// transport error.  For example, Google DNS eventually does so with
		// the "Connection max age expired" message:
		// https://github.com/AdguardTeam/dnsproxy/issues/283
		return true
	case errors.Is(err, quic.Err0RTTRejected):
		// 0-RTT was attempted with a token the server no longer knows, e.g.
		// after a server restart cleared its tokens cache.  This keeps
		// happening until the client's tokens cache is purged.
		return true
	default:
		return false
	}
}

// withDeadline derives a context from parent that expires p.timeout from now.
// When p.timeout isn't positive, parent itself is returned along with a no-op
// cancel function.
func (p *dnsOverQUIC) withDeadline(
	parent context.Context,
) (ctx context.Context, cancel context.CancelFunc) {
	if p.timeout <= 0 {
		return parent, func() {}
	}

	return context.WithDeadline(parent, time.Now().Add(p.timeout))
}
07070100000083000081A4000000000000000000000001650C5921000021DC000000000000000000000000000000000000002600000000dnsproxy-0.55.0/upstream/quic_test.gopackage upstream

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/logging"
	"github.com/stretchr/testify/require"
)

func TestUpstreamDoQ(t *testing.T) {
	srv := startDoQServer(t, 0)
	t.Cleanup(srv.Shutdown)

	address := fmt.Sprintf("quic://%s", srv.addr)

	// lastState is written by the VerifyConnection callback during TLS
	// handshakes.  The same opts are used for the upstream passed to
	// checkRaceCondition below, where handshakes may presumably run
	// concurrently, so every access goes through lastStateMu.
	var (
		lastState   tls.ConnectionState
		lastStateMu sync.Mutex
	)
	opts := &Options{
		InsecureSkipVerify: true,
		VerifyConnection: func(state tls.ConnectionState) error {
			lastStateMu.Lock()
			defer lastStateMu.Unlock()

			lastState = state

			return nil
		},
	}
	u, err := AddressToUpstream(address, opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	uq := u.(*dnsOverQUIC)
	var conn quic.Connection

	// Test that it responds properly
	for i := 0; i < 10; i++ {
		checkUpstream(t, u, address)

		if conn == nil {
			conn = uq.conn
		} else {
			// This way we test that the conn is properly reused.
			require.Equal(t, conn, uq.conn)
		}
	}

	// Close the connection (make sure that we re-establish the connection).
	_ = conn.CloseWithError(quic.ApplicationErrorCode(0), "")

	// Try to establish it again.
	checkUpstream(t, u, address)

	// Make sure that the session has been resumed.
	lastStateMu.Lock()
	didResume := lastState.DidResume
	lastStateMu.Unlock()
	require.True(t, didResume)

	// Re-create the upstream to make the test check initialization and
	// check it for race conditions.
	u, err = AddressToUpstream(address, opts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	checkRaceCondition(u)
}

func TestUpstreamDoQ_serverRestart(t *testing.T) {
	// The first server instance listens on a random port.
	srv := startDoQServer(t, 0)

	// Point a DNS-over-QUIC upstream at it.
	address := fmt.Sprintf("quic://%s", srv.addr)
	upsOpts := &Options{InsecureSkipVerify: true, Timeout: time.Second}
	u, err := AddressToUpstream(address, upsOpts)
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// The upstream must work against the initial server.
	checkUpstream(t, u, address)

	// Remember the port, since every restart must reuse the same address.
	_, portStr, err := net.SplitHostPort(srv.addr)
	require.NoError(t, err)
	port, err := strconv.Atoi(portStr)
	require.NoError(t, err)

	// First restart: stop the server and bring a new one up on that port.
	srv.Shutdown()
	srv = startDoQServer(t, port)

	// The upstream must recover after the restart.
	checkUpstream(t, u, address)

	// While no server is running, an exchange must fail.
	srv.Shutdown()
	_, exchErr := u.Exchange(createTestMessage())
	require.Error(t, exchErr)

	// Second restart: the upstream must recover once again.
	srv = startDoQServer(t, port)
	defer srv.Shutdown()

	checkUpstream(t, u, address)
}

func TestUpstreamDoQ_0RTT(t *testing.T) {
	srv := startDoQServer(t, 0)
	t.Cleanup(srv.Shutdown)

	tracer := &quicTracer{}
	address := fmt.Sprintf("quic://%s", srv.addr)
	u, err := AddressToUpstream(address, &Options{
		InsecureSkipVerify: true,
		QUICTracer:         tracer.TracerForConnection,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	uq := u.(*dnsOverQUIC)
	req := createTestMessage()

	// exchange sends req through the upstream and validates the response.
	exchange := func() {
		resp, exchErr := uq.Exchange(req)
		require.NoError(t, exchErr)
		requireResponse(t, req, resp)
	}

	// The first exchange establishes a connection to the QUIC server.
	exchange()

	// Drop the active connection so that the next exchange must reconnect.
	func() {
		uq.connMu.Lock()
		defer uq.connMu.Unlock()

		require.NoError(t, uq.conn.CloseWithError(QUICCodeNoError, ""))

		uq.conn = nil
	}()

	// The second exchange opens another connection.
	exchange()

	// Exactly two connections must have been traced.
	conns := tracer.getConnectionsInfo()
	require.Len(t, conns, 2)

	// The first connection has no 0-RTT packets, the resumed one does.
	require.False(t, conns[0].is0RTT())
	require.True(t, conns[1].is0RTT())
}

// testDoQServer is an instance of a test DNS-over-QUIC server.
type testDoQServer struct {
	// tlsConfig is the TLS configuration that is used for this server.
	tlsConfig *tls.Config

	// rootCAs is the pool with root certificates used by the test server.
	rootCAs *x509.CertPool

	// listener is the QUIC connections listener.
	listener *quic.EarlyListener

	// addr is the address that this server listens to.
	addr string
}

// Shutdown stops the test server by closing its listener.  A close error is
// logged at the debug level instead of being silently dropped.
func (s *testDoQServer) Shutdown() {
	log.OnCloserError(s.listener, log.DEBUG)
}

// Serve serves DoQ requests, handling every accepted connection in its own
// goroutine.  It returns as soon as accepting fails.  With a background
// context that only happens once the listener is closed, so continuing would
// either spin on the same error or hand a nil connection to the handler.
func (s *testDoQServer) Serve() {
	for {
		conn, err := s.listener.Accept(context.Background())
		if err != nil {
			if err != quic.ErrServerClosed {
				log.Debug("error while accepting a new connection: %v", err)
			}

			// Finish serving, conn is not valid here.
			return
		}

		go s.handleQUICConnection(conn)
	}
}

// handleQUICConnection accepts streams on conn and serves each of them in a
// separate goroutine.  The connection is closed once accepting fails or any
// stream handler returns an error.
func (s *testDoQServer) handleQUICConnection(conn quic.EarlyConnection) {
	closeConn := func() {
		_ = conn.CloseWithError(QUICCodeNoError, "")
	}

	for {
		stream, err := conn.AcceptStream(context.Background())
		if err != nil {
			closeConn()

			return
		}

		go func(str quic.Stream) {
			if handleErr := s.handleQUICStream(str); handleErr != nil {
				closeConn()
			}
		}(stream)
	}
}

// handleQUICStream handles a new QUIC stream: it reads a length-prefixed DNS
// message, responds to it with a length-prefixed response, and closes the
// stream.
func (s *testDoQServer) handleQUICStream(stream quic.Stream) (err error) {
	defer log.OnCloserError(stream, log.DEBUG)

	// Read the 2-octet length prefix and then exactly that many bytes of the
	// message.  io.ReadFull is required since a single Read may return only a
	// part of the data.  Allocating by the decoded length also avoids uint16
	// arithmetic, which overflowed for lengths near the maximum.
	var prefix [2]byte
	_, err = io.ReadFull(stream, prefix[:])
	if err != nil {
		return err
	}

	buf := make([]byte, binary.BigEndian.Uint16(prefix[:]))
	_, err = io.ReadFull(stream, buf)
	if err != nil {
		return err
	}

	req := &dns.Msg{}
	err = req.Unpack(buf)
	if err != nil {
		return err
	}

	resp := respondToTestMessage(req)

	buf, err = resp.Pack()
	if err != nil {
		return err
	}

	buf = proxyutil.AddPrefix(buf)
	_, err = stream.Write(buf)

	return err
}

// startDoQServer starts a test DoQ server listening on 127.0.0.1:port (a zero
// port picks a random one) and serves it in a background goroutine.  The
// server accepts 0-RTT connections.
func startDoQServer(t *testing.T, port int) (s *testDoQServer) {
	tlsConfig, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	tlsConfig.NextProtos = []string{NextProtoDQ}

	quicConf := &quic.Config{
		// Disabling address validation is necessary for 0-RTT.
		RequireAddressValidation: func(net.Addr) (ok bool) { return false },
		Allow0RTT:                true,
	}

	listenAddr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
	listener, err := quic.ListenAddrEarly(listenAddr, tlsConfig, quicConf)
	require.NoError(t, err)

	s = &testDoQServer{
		tlsConfig: tlsConfig,
		rootCAs:   rootCAs,
		listener:  listener,
		addr:      listener.Addr().String(),
	}

	go s.Serve()

	return s
}

// quicTracer is a logging.Tracer that creates a *quicConnTracer for every
// traced connection and keeps all of them for later inspection.
type quicTracer struct {
	logging.NullTracer

	// mu guards tracers and also the fields of every *quicConnTracer stored
	// in it.
	mu sync.Mutex

	// tracers are the connection tracers in the order of their creation.
	tracers []*quicConnTracer
}

// type check
var _ logging.Tracer = (*quicTracer)(nil)

// TracerForConnection implements the logging.Tracer interface for *quicTracer.
// It registers and returns a new tracer for the connection identified by odcid.
func (q *quicTracer) TracerForConnection(
	_ context.Context,
	_ logging.Perspective,
	odcid logging.ConnectionID,
) (connTracer logging.ConnectionTracer) {
	ct := &quicConnTracer{parent: q, id: odcid}

	q.mu.Lock()
	q.tracers = append(q.tracers, ct)
	q.mu.Unlock()

	return ct
}

// connInfo is a snapshot of what has been logged for a single connection.
type connInfo struct {
	// packets are the headers of the logged packets.
	packets []logging.Header

	// id is the connection ID the tracer has been created for.
	id logging.ConnectionID
}

// is0RTT reports whether at least one logged packet of the connection is a
// 0-RTT packet.
func (c *connInfo) is0RTT() (ok bool) {
	for i := range c.packets {
		if logging.PacketTypeFromHeader(&c.packets[i]) == logging.PacketType0RTT {
			return true
		}
	}

	return false
}

// getConnectionsInfo returns a snapshot of every traced connection, in the
// order the connections have been traced.
func (q *quicTracer) getConnectionsInfo() (conns []connInfo) {
	q.mu.Lock()
	defer q.mu.Unlock()

	for i := range q.tracers {
		ct := q.tracers[i]
		conns = append(conns, connInfo{packets: ct.packets, id: ct.id})
	}

	return conns
}

// quicConnTracer is a logging.ConnectionTracer that records the headers of the
// long-header packets sent over a single connection.
type quicConnTracer struct {
	logging.NullConnectionTracer

	// parent is the tracer that created this one.  Its mu guards packets.
	parent *quicTracer

	// packets are the recorded packet headers.
	packets []logging.Header

	// id is the connection ID this tracer has been created for.
	id logging.ConnectionID
}

// type check
var _ logging.ConnectionTracer = (*quicConnTracer)(nil)

// SentLongHeaderPacket implements the logging.ConnectionTracer interface for
// *quicConnTracer.  It records the header of the sent packet.
func (q *quicConnTracer) SentLongHeaderPacket(
	hdr *logging.ExtendedHeader,
	_ logging.ByteCount,
	_ *logging.AckFrame,
	_ []logging.Frame,
) {
	mu := &q.parent.mu
	mu.Lock()
	defer mu.Unlock()

	q.packets = append(q.packets, hdr.Header)
}
07070100000084000081A4000000000000000000000001650C5921000032A0000000000000000000000000000000000000002500000000dnsproxy-0.55.0/upstream/upstream.go// Package upstream implements DNS clients for all known DNS encryption
// protocols.
package upstream

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"net"
	"net/netip"
	"net/url"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/ameshkov/dnscrypt/v2"
	"github.com/ameshkov/dnsstamps"
	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/logging"
)

// Upstream is an interface for a DNS resolver.
type Upstream interface {
	// Exchange sends req to this upstream and returns the received response,
	// or an error if something went wrong.
	Exchange(req *dns.Msg) (resp *dns.Msg, err error)

	// Address returns the address of the upstream DNS resolver.
	Address() (addr string)

	// Closer is used to close the upstream properly.  Exchange shouldn't be
	// called after Close has been called.
	io.Closer
}

// QUICTraceFunc returns the [logging.ConnectionTracer] to use for the QUIC
// connection with the given role (perspective) and connection ID.
type QUICTraceFunc func(
	ctx context.Context,
	perspective logging.Perspective,
	id quic.ConnectionID,
) (connTracer logging.ConnectionTracer)

// Options are the options for AddressToUpstream.  They configure the
// properties of the created upstream.
type Options struct {
	// VerifyServerCertificate is used to set the VerifyPeerCertificate property
	// of the *tls.Config for DNS-over-HTTPS, DNS-over-QUIC, and DNS-over-TLS.
	VerifyServerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error

	// VerifyConnection is used to set the VerifyConnection property
	// of the *tls.Config for DNS-over-HTTPS, DNS-over-QUIC, and DNS-over-TLS.
	VerifyConnection func(state tls.ConnectionState) error

	// VerifyDNSCryptCertificate is the callback the DNSCrypt server certificate
	// will be passed to.  It's called in dnsCrypt.exchangeDNSCrypt.
	// Upstream.Exchange method returns any error caused by it.
	VerifyDNSCryptCertificate func(cert *dnscrypt.Cert) error

	// QUICTracer is an optional callback that allows tracing every QUIC
	// connection and logging every packet that goes through.
	QUICTracer QUICTraceFunc

	// RootCAs is the CertPool that must be used by all upstreams.  Redefining
	// RootCAs makes sense on iOS to overcome the 15MB memory limit of the
	// NEPacketTunnelProvider.
	RootCAs *x509.CertPool

	// CipherSuites is a custom list of TLSv1.2 ciphers.
	CipherSuites []uint16

	// Bootstrap is a list of DNS servers used to resolve the hostnames of
	// upstreams, e.g. DNS-over-HTTPS, DNS-over-TLS, or DNS-over-QUIC ones.
	// Plain DNS, DNSCrypt, or DNS-over-HTTPS/DNS-over-TLS with IP addresses
	// (not hostnames) could be used.  An empty string element stands for a
	// default [net.Resolver], which is also the only resolver used when the
	// list is empty.
	Bootstrap []string

	// ServerIPAddrs is the list of IP addresses of the upstream DNS server.
	// If not empty, bootstrap DNS servers won't be used at all.
	ServerIPAddrs []net.IP

	// HTTPVersions is a list of HTTP versions that should be supported by the
	// DNS-over-HTTPS client.  If not set, HTTP/1.1 and HTTP/2 will be used.
	HTTPVersions []HTTPVersion

	// Timeout is the default upstream timeout.  It's also used as a timeout for
	// bootstrap DNS requests.  Zero value disables the timeout.
	Timeout time.Duration

	// InsecureSkipVerify disables verifying the server's certificate.
	InsecureSkipVerify bool

	// PreferIPv6 tells the bootstrapper to prefer IPv6 addresses for an
	// upstream.
	PreferIPv6 bool
}

// Clone returns a shallow copy of o: slices, pointers, and callbacks are
// shared between o and the clone.
func (o *Options) Clone() (clone *Options) {
	copied := *o

	return &copied
}

// HTTPVersion enumerates the supported HTTP versions.  Its values double as
// the ALPN protocol identifiers.
type HTTPVersion string

// Supported HTTP versions.
const (
	// HTTPVersion11 is HTTP/1.1.
	HTTPVersion11 HTTPVersion = "http/1.1"

	// HTTPVersion2 is HTTP/2.
	HTTPVersion2 HTTPVersion = "h2"

	// HTTPVersion3 is HTTP/3.
	HTTPVersion3 HTTPVersion = "h3"
)

// DefaultHTTPVersions are the HTTP versions the DNS-over-HTTPS client uses
// when none are configured.
var DefaultHTTPVersions = []HTTPVersion{
	HTTPVersion11,
	HTTPVersion2,
}

// Default ports of the supported DNS protocols.
const (
	// defaultPortPlain is used for plain DNS over UDP and TCP.
	defaultPortPlain = 53

	// defaultPortDoH is used for DNS-over-HTTPS.
	defaultPortDoH = 443

	// defaultPortDoT is used for DNS-over-TLS.
	defaultPortDoT = 853

	// defaultPortDoQ is used for DNS-over-QUIC.  Before version -10 of the
	// draft, experiments were directed to use ports 8853 and 784.
	//
	// See https://www.rfc-editor.org/rfc/rfc9250.html#name-port-selection.
	defaultPortDoQ = 853
)

// AddressToUpstream converts addr to an Upstream using the specified options.
// addr is either a URL or a plain address (a domain name or an IP):
//
//   - udp://5.3.5.3:53 or 5.3.5.3:53 for plain DNS using IP address;
//   - udp://name.server:53 or name.server:53 for plain DNS using domain name;
//   - tcp://5.3.5.3:53 for plain DNS-over-TCP using IP address;
//   - tcp://name.server:53 for plain DNS-over-TCP using domain name;
//   - tls://5.3.5.3:853 for DNS-over-TLS using IP address;
//   - tls://name.server:853 for DNS-over-TLS using domain name;
//   - https://5.3.5.3:443/dns-query for DNS-over-HTTPS using IP address;
//   - https://name.server:443/dns-query for DNS-over-HTTPS using domain name;
//   - quic://5.3.5.3:853 for DNS-over-QUIC using IP address;
//   - quic://name.server:853 for DNS-over-QUIC using domain name;
//   - h3://dns.google for DNS-over-HTTPS that only works with HTTP/3;
//   - sdns://... for DNS stamp, see https://dnscrypt.info/stamps-specifications.
//
// When addr has no port, the default port of the protocol is used.  An addr
// without a scheme is treated as plain UDP.
//
// opts are applied to u and shouldn't be modified afterwards; nil is a valid
// value.
//
// TODO(e.burkov):  Clone opts?
func AddressToUpstream(addr string, opts *Options) (u Upstream, err error) {
	if opts == nil {
		opts = &Options{}
	}

	if strings.Contains(addr, "://") {
		var parsed *url.URL
		parsed, err = url.Parse(addr)
		if err != nil {
			return nil, fmt.Errorf("failed to parse %s: %w", addr, err)
		}

		return urlToUpstream(parsed, opts)
	}

	// Probably a plain UDP upstream given as address or address:port, so
	// validate the port if there is one.
	if _, port, splitErr := net.SplitHostPort(addr); splitErr == nil {
		if _, err = strconv.ParseUint(port, 10, 16); err != nil {
			return nil, fmt.Errorf("invalid address %s: %w", addr, err)
		}
	}

	return urlToUpstream(&url.URL{Scheme: "udp", Host: addr}, opts)
}

// urlToUpstream creates an Upstream for uu, choosing the implementation by the
// URL scheme.  Unknown schemes result in an error.
func urlToUpstream(uu *url.URL, opts *Options) (u Upstream, err error) {
	scheme := uu.Scheme
	switch scheme {
	case "udp", "tcp":
		return newPlain(uu, opts)
	case "tls":
		return newDoT(uu, opts)
	case "https", "h3":
		return newDoH(uu, opts)
	case "quic":
		return newDoQ(uu, opts)
	case "sdns":
		return parseStamp(uu, opts)
	}

	return nil, fmt.Errorf("unsupported url scheme: %s", scheme)
}

// parseStamp converts a DNS stamp to an Upstream.  If the stamp contains a
// server address, that IP becomes the only server address of the upstream.
// In that case opts is cloned first, so the caller's options are left
// unchanged; they may be shared with other upstreams.
func parseStamp(upsURL *url.URL, opts *Options) (u Upstream, err error) {
	stamp, err := dnsstamps.NewServerStampFromString(upsURL.String())
	if err != nil {
		return nil, fmt.Errorf("failed to parse %s: %w", upsURL, err)
	}

	// TODO(e.burkov):  Port?
	if stamp.ServerAddrStr != "" {
		host, _, sErr := netutil.SplitHostPort(stamp.ServerAddrStr)
		if sErr != nil {
			host = stamp.ServerAddrStr
		}

		// Parse and add to options.
		ip := net.ParseIP(host)
		if ip == nil {
			return nil, fmt.Errorf("invalid server stamp address %s", stamp.ServerAddrStr)
		}

		// Don't modify the caller's options.  The same *Options may be used
		// to create several upstreams, and the stamp's address must only
		// apply to this one.
		opts = opts.Clone()

		// TODO(e.burkov):  Append?
		opts.ServerIPAddrs = []net.IP{ip}
	}

	switch stamp.Proto {
	case dnsstamps.StampProtoTypePlain:
		return newPlain(&url.URL{Scheme: "udp", Host: stamp.ServerAddrStr}, opts)
	case dnsstamps.StampProtoTypeDNSCrypt:
		return newDNSCrypt(upsURL, opts), nil
	case dnsstamps.StampProtoTypeDoH:
		return newDoH(&url.URL{Scheme: "https", Host: stamp.ProviderName, Path: stamp.Path}, opts)
	case dnsstamps.StampProtoTypeDoQ:
		return newDoQ(&url.URL{Scheme: "quic", Host: stamp.ProviderName, Path: stamp.Path}, opts)
	case dnsstamps.StampProtoTypeTLS:
		return newDoT(&url.URL{Scheme: "tls", Host: stamp.ProviderName}, opts)
	default:
		return nil, fmt.Errorf("unsupported stamp protocol %s", &stamp.Proto)
	}
}

// addPort appends port to the host of u unless the host already has a port.
// A nil u is left alone.
func addPort(u *url.URL, port uint16) {
	if u == nil {
		return
	}

	if _, _, err := net.SplitHostPort(u.Host); err != nil {
		u.Host = netutil.JoinHostPort(u.Host, port)
	}
}

// logBegin logs the start of DNS request resolution and should be called right
// before dialing the upstream.  n is the [network] the request is sent over.
// The type and name of the first question are logged, if there is one.
func logBegin(upstreamAddress string, n network, req *dns.Msg) {
	var qtype, target string
	if len(req.Question) > 0 {
		q := req.Question[0]
		qtype, target = dns.Type(q.Qtype).String(), q.Name
	}

	log.Debug("%s: sending request over %s: %s %s", upstreamAddress, n, qtype, target)
}

// logFinish logs the result of a DNS request sent to upstreamAddress over the
// [network] n: "ok" on success, or the error text otherwise.
func logFinish(upstreamAddress string, n network, err error) {
	var status string
	if err == nil {
		status = "ok"
	} else {
		status = err.Error()
	}

	log.Debug("%s: response received over %s: %s", upstreamAddress, n, status)
}

// DialerInitializer returns a dial handler, creating it on the first call.
// Initializers built by newDialerInitializer cache the handler after the
// first successful call, so later calls return it without resolving again.
type DialerInitializer func() (dialHandler bootstrap.DialHandler, err error)

// newDialerInitializer creates an initializer of the dialer for the address
// in u.
//
//   - If opts.ServerIPAddrs is set, those addresses (with u's port) are used
//     as they are.
//   - If u's host is already an IP, it's used as it is.
//   - Otherwise the host is resolved with the bootstrap resolvers on the
//     first call of the initializer, and the resulting handler is cached.
func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer, err error) {
	host, port, err := netutil.SplitHostPort(u.Host)
	if err != nil {
		return nil, fmt.Errorf("invalid address: %s: %w", u.Host, err)
	}

	// static wraps an already prepared handler into an initializer.
	static := func(h bootstrap.DialHandler) (DialerInitializer, error) {
		return func() (bootstrap.DialHandler, error) { return h, nil }, nil
	}

	if len(opts.ServerIPAddrs) > 0 {
		// The addresses from the options are used instead of resolving.
		addrs := make([]string, len(opts.ServerIPAddrs))
		for i, ip := range opts.ServerIPAddrs {
			addrs[i] = netutil.JoinHostPort(ip.String(), port)
		}

		return static(bootstrap.NewDialContext(opts.Timeout, addrs...))
	}

	if _, parseErr := netip.ParseAddr(host); parseErr == nil {
		// The host is already an IP, nothing to resolve.
		return static(bootstrap.NewDialContext(opts.Timeout, u.Host))
	}

	resolvers, err := newResolvers(opts)
	if err != nil {
		// Don't wrap the error since it's informative enough as is.
		return nil, err
	}

	var cached atomic.Value

	return func() (h bootstrap.DialHandler, resErr error) {
		if loaded, ok := cached.Load().(bootstrap.DialHandler); ok {
			return loaded, nil
		}

		// TODO(e.burkov):  It may appear that several exchanges will try to
		// resolve the upstream hostname at the same time.  Currently, the
		// first stored value wins and is returned to all of them, but ideally
		// we should resolve only once.
		resolved, resolveErr := bootstrap.ResolveDialContext(u, opts.Timeout, resolvers, opts.PreferIPv6)
		if resolveErr != nil {
			return nil, fmt.Errorf("creating dial handler: %w", resolveErr)
		}

		if cached.CompareAndSwap(nil, resolved) {
			return resolved, nil
		}

		return cached.Load().(bootstrap.DialHandler), nil
	}, nil
}

// newResolvers prepares resolvers for bootstrapping.  If opts.Bootstrap is
// empty, a single default [net.Resolver] is returned.  Otherwise, every element
// of opts.Bootstrap becomes a resolver: an empty string becomes a default
// [net.Resolver], and any other value is turned into an upstream resolver.
func newResolvers(opts *Options) (resolvers []Resolver, err error) {
	if len(opts.Bootstrap) == 0 {
		return []Resolver{&net.Resolver{}}, nil
	}

	resolvers = make([]Resolver, 0, len(opts.Bootstrap))
	for _, boot := range opts.Bootstrap {
		if boot != "" {
			r, rErr := NewUpstreamResolver(boot, opts)
			if rErr != nil {
				return nil, fmt.Errorf("preparing bootstrap resolver: %w", rErr)
			}

			resolvers = append(resolvers, r)
		} else {
			resolvers = append(resolvers, &net.Resolver{})
		}
	}

	return resolvers, nil
}
07070100000085000081A4000000000000000000000001650C592100004200000000000000000000000000000000000000002A00000000dnsproxy-0.55.0/upstream/upstream_test.gopackage upstream

import (
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"io"
	"math/big"
	"net"
	"net/netip"
	"net/url"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/log"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/ameshkov/dnsstamps"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TODO(ameshkov): make tests here not depend on external servers.

func TestMain(m *testing.M) {
	// Silence the package logger for all the tests.
	log.SetOutput(io.Discard)

	code := m.Run()
	os.Exit(code)
}

func TestUpstream_bootstrapTimeout(t *testing.T) {
	const (
		timeout = 100 * time.Millisecond
		count   = 10
	)

	// Test listener that never accepts connections to emulate faulty bootstrap.
	udpListener, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, udpListener.Close)

	// Create an upstream that uses this faulty bootstrap.
	u, err := AddressToUpstream("tls://random-domain-name", &Options{
		Bootstrap: []string{udpListener.LocalAddr().String()},
		Timeout:   timeout,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	// Every goroutine sends exactly one value to either ch or abort.  Both
	// channels can hold a value from each goroutine, so no goroutine blocks
	// forever when the test has already failed and stopped receiving.
	ch := make(chan int, count)
	abort := make(chan string, count)

	// Wait for all the goroutines before the test completes, since logging
	// from a goroutine after that panics.  Cleanups run in LIFO order, so this
	// also happens before the upstream is closed.
	wg := &sync.WaitGroup{}
	t.Cleanup(wg.Wait)

	for i := 0; i < count; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()

			t.Logf("Start %d", idx)
			req := createTestMessage()

			start := time.Now()
			_, rErr := u.Exchange(req)
			elapsed := time.Since(start)

			if rErr == nil {
				// Must not happen since bootstrap server cannot work.
				abort <- "the upstream must have timed out"

				return
			}

			// Check that the test didn't take too much time compared to the
			// configured timeout.  The actual elapsed time may be higher than
			// the timeout due to the execution environment, 3 is an arbitrarily
			// chosen multiplier to account for that.
			if elapsed > 3*timeout {
				abort <- fmt.Sprintf(
					"exchange took more time than the configured timeout: %s",
					elapsed,
				)

				return
			}

			t.Logf("Finished %d", idx)
			ch <- idx
		}(i)
	}

	for i := 0; i < count; i++ {
		select {
		case res := <-ch:
			t.Logf("Got result from %d", res)
		case msg := <-abort:
			t.Fatalf("Aborted from the goroutine: %s", msg)
		case <-time.After(timeout * 10):
			t.Fatalf("No response in time")
		}
	}
}

func TestUpstreams(t *testing.T) {
	upstreams := []struct {
		address   string
		bootstrap []string
	}{{
		address:   "8.8.8.8:53",
		bootstrap: []string{"8.8.8.8:53"},
	}, {
		address:   "1.1.1.1",
		bootstrap: []string{},
	}, {
		address:   "1.1.1.1",
		bootstrap: []string{"1.0.0.1"},
	}, {
		address:   "tcp://1.1.1.1:53",
		bootstrap: []string{},
	}, {
		address:   "94.140.14.14:5353",
		bootstrap: []string{},
	}, {
		address:   "tls://1.1.1.1",
		bootstrap: []string{},
	}, {
		address:   "tls://9.9.9.9:853",
		bootstrap: []string{},
	}, {
		address:   "tls://dns.adguard.com",
		bootstrap: []string{"8.8.8.8:53"},
	}, {
		address:   "tls://dns.adguard.com:853",
		bootstrap: []string{"8.8.8.8:53"},
	}, {
		address:   "tls://dns.adguard.com:853",
		bootstrap: []string{"8.8.8.8"},
	}, {
		address:   "tls://one.one.one.one",
		bootstrap: []string{},
	}, {
		address:   "https://1dot1dot1dot1.cloudflare-dns.com/dns-query",
		bootstrap: []string{"8.8.8.8:53"},
	}, {
		address:   "https://dns.google/dns-query",
		bootstrap: []string{},
	}, {
		address:   "https://doh.opendns.com/dns-query",
		bootstrap: []string{},
	}, {
		// AdGuard DNS (DNSCrypt)
		address:   "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
		bootstrap: []string{},
	}, {
		// AdGuard Family (DNSCrypt)
		address:   "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMjo1NDQzILgxXdexS27jIKRw3C7Wsao5jMnlhvhdRUXWuMm1AFq6ITIuZG5zY3J5cHQuZmFtaWx5Lm5zMS5hZGd1YXJkLmNvbQ",
		bootstrap: []string{"8.8.8.8"},
	}, {
		// Cloudflare DNS (DNS-over-HTTPS)
		address:   "sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk",
		bootstrap: []string{"8.8.8.8:53"},
	}, {
		// Google (Plain)
		address:   "sdns://AAcAAAAAAAAABzguOC44Ljg",
		bootstrap: []string{},
	}, {
		// AdGuard DNS (DNS-over-TLS)
		address:   "sdns://AwAAAAAAAAAAAAAPZG5zLmFkZ3VhcmQuY29t",
		bootstrap: []string{"8.8.8.8:53"},
	}, {
		// AdGuard DNS (DNS-over-QUIC)
		address:   "sdns://BAcAAAAAAAAAAAAXZG5zLmFkZ3VhcmQtZG5zLmNvbTo3ODQ",
		bootstrap: []string{"8.8.8.8:53"},
	}, {
		// Cloudflare DNS (DNS-over-HTTPS)
		address:   "https://1.1.1.1/dns-query",
		bootstrap: []string{},
	}, {
		// AdGuard DNS (DNS-over-QUIC)
		address:   "quic://dns.adguard-dns.com",
		bootstrap: []string{"1.1.1.1:53"},
	}, {
		// Google DNS (HTTP3)
		address:   "h3://dns.google/dns-query",
		bootstrap: []string{},
	}}
	for _, test := range upstreams {
		t.Run(test.address, func(t *testing.T) {
			u, err := AddressToUpstream(
				test.address,
				&Options{Bootstrap: test.bootstrap, Timeout: timeout},
			)
			require.NoErrorf(t, err, "failed to generate upstream from address %s", test.address)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkUpstream(t, u, test.address)
		})
	}
}

// TestAddressToUpstream checks that AddressToUpstream normalizes addresses
// of every supported scheme and fills in the scheme's default port.
func TestAddressToUpstream(t *testing.T) {
	bootstrapOpts := &Options{Bootstrap: []string{"1.1.1.1"}}

	type testCase struct {
		addr string
		opt  *Options
		want string
	}

	testCases := []testCase{
		{addr: "1.1.1.1", opt: nil, want: "1.1.1.1:53"},
		{addr: "one.one.one.one", opt: nil, want: "one.one.one.one:53"},
		{addr: "udp://one.one.one.one", opt: nil, want: "one.one.one.one:53"},
		{addr: "tcp://one.one.one.one", opt: bootstrapOpts, want: "tcp://one.one.one.one:53"},
		{addr: "tls://one.one.one.one", opt: bootstrapOpts, want: "tls://one.one.one.one:853"},
		{addr: "https://one.one.one.one", opt: bootstrapOpts, want: "https://one.one.one.one:443"},
		{addr: "h3://one.one.one.one", opt: bootstrapOpts, want: "https://one.one.one.one:443"},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.addr, func(t *testing.T) {
			u, err := AddressToUpstream(tc.addr, tc.opt)
			require.NoError(t, err)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			assert.Equal(t, tc.want, u.Address())
		})
	}
}

// TestAddressToUpstream_bads checks the error messages AddressToUpstream
// returns for unsupported schemes and malformed ports.
func TestAddressToUpstream_bads(t *testing.T) {
	const outOfRange = `strconv.ParseUint: parsing "1234567": value out of range`

	testCases := []struct {
		addr       string
		wantErrMsg string
	}{
		{addr: "asdf://1.1.1.1", wantErrMsg: "unsupported url scheme: asdf"},
		{addr: "12345.1.1.1:1234567", wantErrMsg: "invalid address 12345.1.1.1:1234567: " + outOfRange},
		{addr: ":1234567", wantErrMsg: "invalid address :1234567: " + outOfRange},
		{addr: "host:", wantErrMsg: `invalid address host:: strconv.ParseUint: parsing "": invalid syntax`},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.addr, func(t *testing.T) {
			_, gotErr := AddressToUpstream(tc.addr, nil)
			testutil.AssertErrorMsg(t, tc.wantErrMsg, gotErr)
		})
	}
}

func TestUpstreamDoTBootstrap(t *testing.T) {
	upstreams := []struct {
		address   string
		bootstrap []string
	}{{
		address:   "tls://one.one.one.one/",
		bootstrap: []string{"tls://1.1.1.1"},
	}, {
		address:   "tls://one.one.one.one/",
		bootstrap: []string{"https://1.1.1.1/dns-query"},
	}, {
		address: "tls://one.one.one.one/",
		// Cisco OpenDNS
		bootstrap: []string{"sdns://AQAAAAAAAAAADjIwOC42Ny4yMjAuMjIwILc1EUAgbyJdPivYItf9aR6hwzzI1maNDL4Ev6vKQ_t5GzIuZG5zY3J5cHQtY2VydC5vcGVuZG5zLmNvbQ"},
	}}

	for _, tc := range upstreams {
		t.Run(tc.address, func(t *testing.T) {
			u, err := AddressToUpstream(tc.address, &Options{
				Bootstrap: tc.bootstrap,
				Timeout:   timeout,
			})
			require.NoErrorf(t, err, "failed to generate upstream from address %s", tc.address)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkUpstream(t, u, tc.address)
		})
	}
}

// Test for DoH and DoT upstreams with two bootstraps (only one is valid)
func TestUpstreamsInvalidBootstrap(t *testing.T) {
	upstreams := []struct {
		address   string
		bootstrap []string
	}{{
		address:   "tls://dns.adguard.com",
		bootstrap: []string{"1.1.1.1:555", "8.8.8.8:53"},
	}, {
		address:   "tls://dns.adguard.com:853",
		bootstrap: []string{"1.0.0.1", "8.8.8.8:535"},
	}, {
		address:   "https://1dot1dot1dot1.cloudflare-dns.com/dns-query",
		bootstrap: []string{"8.8.8.1", "1.0.0.1"},
	}, {
		address:   "https://doh.opendns.com:443/dns-query",
		bootstrap: []string{"1.2.3.4:79", "8.8.8.8:53"},
	}, {
		// Cloudflare DNS (DoH)
		address:   "sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk",
		bootstrap: []string{"8.8.8.8:53", "8.8.8.1:53"},
	}, {
		// AdGuard DNS (DNS-over-TLS)
		address:   "sdns://AwAAAAAAAAAAAAAPZG5zLmFkZ3VhcmQuY29t",
		bootstrap: []string{"1.2.3.4:55", "8.8.8.8"},
	}}

	for _, tc := range upstreams {
		t.Run(tc.address, func(t *testing.T) {
			u, err := AddressToUpstream(tc.address, &Options{
				Bootstrap: tc.bootstrap,
				Timeout:   timeout,
			})
			require.NoErrorf(t, err, "failed to generate upstream from address %s", tc.address)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkUpstream(t, u, tc.address)
		})
	}

	_, err := AddressToUpstream("tls://example.org", &Options{
		Bootstrap: []string{"8.8.8.8", "asdfasdf"},
	})
	assert.Error(t, err) // bad bootstrap "asdfasdf"
}

// TestUpstreamsWithServerIP checks that DoT and DoH upstreams connect to
// explicitly provided server IPs (or the IP inside a DNS stamp) rather than
// resolving their hostnames through the bootstrap.
func TestUpstreamsWithServerIP(t *testing.T) {
	// The bootstrap is deliberately unusable, so the test fails if an upstream
	// tries to use it.
	invalidBootstrap := []string{"1.2.3.4:55"}

	handler := func(w dns.ResponseWriter, m *dns.Msg) {
		require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(m)))
	}
	dotSrv := startDoTServer(t, handler)
	dohSrv := startDoHServer(t, testDoHServerOptions{})

	_, dohPort, err := net.SplitHostPort(dohSrv.addr)
	require.NoError(t, err)

	localhost := netutil.IPv4Localhost()
	// localIPs returns a fresh slice on every call so that test cases don't
	// share a backing array.
	localIPs := func() (ips []net.IP) {
		return []net.IP{localhost.AsSlice()}
	}

	dotAddrPort := netip.AddrPortFrom(localhost, uint16(dotSrv.port)).String()
	dotStamp := (&dnsstamps.ServerStamp{
		ServerAddrStr: dotAddrPort,
		Proto:         dnsstamps.StampProtoTypeTLS,
		ProviderName:  dotAddrPort,
	}).String()
	dohStamp := (&dnsstamps.ServerStamp{
		ServerAddrStr: dohSrv.addr,
		Proto:         dnsstamps.StampProtoTypeDoH,
		ProviderName:  dohSrv.addr,
		Path:          "/dns-query",
	}).String()

	testCases := []struct {
		name      string
		address   string
		serverIPs []net.IP
	}{
		{name: "dot", address: fmt.Sprintf("tls://some.dns.server:%d", dotSrv.port), serverIPs: localIPs()},
		{name: "doh", address: fmt.Sprintf("https://some.dns.server:%s/dns-query", dohPort), serverIPs: localIPs()},
		{name: "dot_stamp", address: dotStamp, serverIPs: nil},
		{name: "doh_stamp", address: dohStamp, serverIPs: nil},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			u, uErr := AddressToUpstream(tc.address, &Options{
				Bootstrap:          invalidBootstrap,
				Timeout:            timeout,
				ServerIPAddrs:      tc.serverIPs,
				InsecureSkipVerify: true,
			})
			require.NoError(t, uErr)
			testutil.CleanupAndRequireSuccess(t, u.Close)

			checkUpstream(t, u, tc.address)
		})
	}
}

// TestAddPort checks that addPort appends the port only when the URL host
// doesn't already carry one, and brackets bare IPv6 addresses.
func TestAddPort(t *testing.T) {
	type testCase struct {
		name string
		want string
		host string
		port uint16
	}

	testCases := []testCase{
		{name: "empty", want: ":0", host: "", port: 0},
		{name: "hostname", want: "example.org:53", host: "example.org", port: 53},
		{name: "ipv4", want: "1.2.3.4:1", host: "1.2.3.4", port: 1},
		{name: "ipv6", want: "[::1]:1", host: "::1", port: 1},
		{name: "ipv6_with_brackets", want: "[::1]:1", host: "[::1]", port: 1},
		{name: "hostname_with_port", want: "example.org:54", host: "example.org:54", port: 53},
		{name: "ipv4_with_port", want: "1.2.3.4:2", host: "1.2.3.4:2", port: 1},
		{name: "ipv6_with_brackets_and_port", want: "[::1]:2", host: "[::1]:2", port: 1},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			u := &url.URL{Host: tc.host}
			addPort(u, tc.port)

			assert.Equal(t, tc.want, u.Host)
		})
	}
}

// checkUpstream exchanges a freshly created test message with u and requires
// a valid reply.  addr is only used to label the failure message.
func checkUpstream(t *testing.T, u Upstream, addr string) {
	t.Helper()

	msg := createTestMessage()

	resp, err := u.Exchange(msg)
	require.NoErrorf(t, err, "couldn't talk to upstream %s", addr)

	requireResponse(t, msg, resp)
}

// checkRaceCondition starts several goroutines that concurrently send test
// messages through u and waits for all of them to finish.  It is meant to be
// run under the race detector; exchange results are discarded.
func checkRaceCondition(u Upstream) {
	const (
		// reqPerWorker is the number of requests each goroutine sends.
		reqPerWorker = 10
		// workers is the overall number of goroutines to run.
		workers = 3
	)

	var wg sync.WaitGroup
	wg.Add(workers)

	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()

			for r := 0; r < reqPerWorker; r++ {
				// Exchange errors are ignored: only data races matter here.
				_, _ = u.Exchange(createTestMessage())
			}
		}()
	}

	wg.Wait()
}

// createTestMessage returns the standard A query used throughout the tests.
// Replies to it are validated by requireResponse.
func createTestMessage() (m *dns.Msg) {
	const testHost = "google-public-dns-a.google.com"

	m = createHostTestMessage(testHost)

	return m
}

// respondToTestMessage builds a reply to m containing a single A record that
// maps google-public-dns-a.google.com. to 8.8.8.8 with a TTL of 100.  It
// matches what requireResponse expects for a message from createTestMessage.
func respondToTestMessage(m *dns.Msg) (resp *dns.Msg) {
	answer := &dns.A{
		Hdr: dns.RR_Header{
			Name:   "google-public-dns-a.google.com.",
			Rrtype: dns.TypeA,
			Class:  dns.ClassINET,
			Ttl:    100,
		},
		A: net.IPv4(8, 8, 8, 8),
	}

	resp = new(dns.Msg)
	resp.SetReply(m)
	resp.Answer = append(resp.Answer, answer)

	return resp
}

// createHostTestMessage returns a recursion-desired A/IN query for host with a
// random message ID.  host is converted to a fully-qualified name.
func createHostTestMessage(host string) (req *dns.Msg) {
	q := dns.Question{
		Name:   dns.Fqdn(host),
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	req = &dns.Msg{}
	req.Id = dns.Id()
	req.RecursionDesired = true
	req.Question = []dns.Question{q}

	return req
}

// requireResponse requires that reply answers req with exactly one A record
// pointing to 8.8.8.8 and carrying req's message ID, as produced for messages
// from createTestMessage.
func requireResponse(t require.TestingT, req, reply *dns.Msg) {
	require.NotNil(t, reply)

	answers := reply.Answer
	require.Lenf(t, answers, 1, "wrong number of answers: %d", len(answers))
	require.Equal(t, req.Id, reply.Id)

	rr := answers[0]
	a, ok := rr.(*dns.A)
	require.Truef(t, ok, "wrong answer type: %v", rr)

	want := net.IPv4(8, 8, 8, 8)
	require.Equalf(t, want, a.A.To16(), "wrong answer: %v", a.A)
}

// createServerTLSConfig creates a test TLS configuration backed by a freshly
// generated, self-signed CA certificate for tlsServerName.  It returns a
// *tls.Config that can be used for both the server and the client, and a
// certificate pool containing that certificate, suitable as the client's root
// CAs.  Any failure fails the test through tb.
//
// TODO(ameshkov): start using rootCAs in tests instead of InsecureVerify.
func createServerTLSConfig(
	tb testing.TB,
	tlsServerName string,
) (tlsConfig *tls.Config, rootCAs *x509.CertPool) {
	tb.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(tb, err)

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	require.NoError(tb, err)

	notBefore := time.Now()
	// The certificate is valid for roughly five years.
	notAfter := notBefore.Add(5 * 365 * time.Hour * 24)

	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"AdGuard Tests"},
		},
		NotBefore: notBefore,
		NotAfter:  notAfter,

		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	template.DNSNames = append(template.DNSNames, tlsServerName)

	derBytes, err := x509.CreateCertificate(
		rand.Reader,
		&template,
		// The template is its own parent, so the certificate is self-signed.
		&template,
		publicKey(privateKey),
		privateKey,
	)
	require.NoError(tb, err)

	certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPem := pem.EncodeToMemory(
		&pem.Block{
			Type:  "RSA PRIVATE KEY",
			Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
		},
	)

	cert, err := tls.X509KeyPair(certPem, keyPem)
	require.NoError(tb, err)

	rootCAs = x509.NewCertPool()
	// AppendCertsFromPEM reports whether any certificate was added.  An empty
	// pool would make client-side verification fail in confusing ways later,
	// so fail here instead.
	ok := rootCAs.AppendCertsFromPEM(certPem)
	require.True(tb, ok, "adding generated certificate to root CAs")

	tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{cert},
		ServerName:   tlsServerName,
		RootCAs:      rootCAs,
		MinVersion:   tls.VersionTLS12,
	}

	return tlsConfig, rootCAs
}

// publicKey extracts the public key from the specified private key.
func publicKey(priv any) (pub any) {
	switch k := priv.(type) {
	case *rsa.PrivateKey:
		return &k.PublicKey
	case *ecdsa.PrivateKey:
		return &k.PublicKey
	default:
		return nil
	}
}
07070100000086000081A4000000000000000000000001650C592100001410000000000000000000000000000000000000002D00000000dnsproxy-0.55.0/upstream/upstreamresolver.gopackage upstream

import (
	"context"
	"fmt"
	"net/netip"

	"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
	proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
)

// Resolver is an alias for [bootstrap.Resolver] to avoid the import cycle.
// Values of this type are returned by [NewUpstreamResolver].
type Resolver = bootstrap.Resolver

// NewUpstreamResolver creates an upstream that can be used as [Resolver].
// resolverAddress format is the same as in the [AddressToUpstream], except that
// it also shouldn't need a bootstrap, i.e. have an IP address in hostname, or
// be a DNSCrypt.  resolverAddress must not be empty, use another [Resolver]
// instead, e.g. [net.Resolver].
//
// Only the Timeout, VerifyServerCertificate, and PreferIPv6 fields of opts are
// used, and opts may be nil.  If the created upstream turns out not to be a
// valid bootstrap, it is closed before returning the error.  Callers should
// not use r when err is not nil.
func NewUpstreamResolver(resolverAddress string, opts *Options) (r Resolver, err error) {
	upsOpts := &Options{}

	// TODO(ameshkov):  Aren't other options needed here?
	if opts != nil {
		upsOpts.Timeout = opts.Timeout
		upsOpts.VerifyServerCertificate = opts.VerifyServerCertificate
		upsOpts.PreferIPv6 = opts.PreferIPv6
	}

	ur := upstreamResolver{}

	ur.Upstream, err = AddressToUpstream(resolverAddress, upsOpts)
	if err != nil {
		err = fmt.Errorf("creating upstream: %w", err)
		log.Error("upstream bootstrap: %s", err)

		return ur, err
	}

	if err = validateBootstrap(ur.Upstream); err != nil {
		log.Error("upstream bootstrap %s: %s", resolverAddress, err)

		// The upstream is discarded below, so release whatever it holds.  The
		// validation error is what the caller needs, so a close error is only
		// logged.
		if closeErr := ur.Upstream.Close(); closeErr != nil {
			log.Debug("upstream bootstrap %s: closing: %s", resolverAddress, closeErr)
		}

		ur.Upstream = nil

		return ur, err
	}

	return ur, nil
}

// validateBootstrap reports whether upstream may act as a bootstrap DNS
// server.  DNSCrypt upstreams are always accepted.  Plain DNS, DNS-over-TLS,
// DNS-over-HTTPS, and DNS-over-QUIC upstreams are accepted only when their
// hostname is an IP address.  Any other upstream type is rejected.  A non-nil
// error is annotated with the upstream's address.
func validateBootstrap(upstream Upstream) (err error) {
	var hostname string

	switch u := upstream.(type) {
	case *dnsCrypt:
		return nil
	case *dnsOverTLS:
		hostname = u.addr.Hostname()
	case *dnsOverHTTPS:
		hostname = u.addr.Hostname()
	case *dnsOverQUIC:
		hostname = u.addr.Hostname()
	case *plainDNS:
		hostname = u.addr.Hostname()
	default:
		err = fmt.Errorf("unknown upstream type: %T", u)

		return errors.Annotate(err, "bootstrap %s: %w", upstream.Address())
	}

	_, err = netip.ParseAddr(hostname)

	return errors.Annotate(err, "bootstrap %s: %w", upstream.Address())
}

// upstreamResolver is a wrapper around Upstream that implements the
// [bootstrap.Resolver] interface.  It sorts the resolved addresses preferring
// IPv4.
type upstreamResolver struct {
	// Upstream is embedded here to avoid implementing another Upstream's
	// methods.  NewUpstreamResolver leaves it nil when creation or validation
	// fails, and LookupNetIP returns no addresses in that case.
	Upstream
}

// Compile-time check that upstreamResolver implements [Resolver].
var _ Resolver = upstreamResolver{}

// LookupNetIP implements the [Resolver] interface for upstreamResolver.  It
// resolves host through the embedded Upstream.  Network "ip4" queries A
// records, "ip6" queries AAAA records, and "ip" queries both concurrently;
// any other network is an error.  For "ip", an error is returned only when no
// address was resolved at all.  The result is sorted preferring IPv4.  An
// empty host or a nil Upstream yields an empty slice and no error.
//
// TODO(e.burkov):  Use context.
func (r upstreamResolver) LookupNetIP(
	_ context.Context,
	network string,
	host string,
) (ipAddrs []netip.Addr, err error) {
	// TODO(e.burkov):  Investigate when [r.Upstream] is nil and why.
	if host == "" || r.Upstream == nil {
		return []netip.Addr{}, nil
	}

	fqdn := dns.Fqdn(host)

	var (
		rrSets [][]dns.RR
		errs   []error
	)

	switch network {
	case "ip4", "ip6":
		qtype := dns.TypeAAAA
		if network == "ip4" {
			qtype = dns.TypeA
		}

		resp, resolveErr := r.resolve(fqdn, qtype)
		if resolveErr != nil {
			return []netip.Addr{}, resolveErr
		}

		rrSets = append(rrSets, resp.Answer)
	case "ip":
		const queries = 2

		resCh := make(chan *resolveResult, queries)

		go r.resolveAsync(resCh, fqdn, dns.TypeA)
		go r.resolveAsync(resCh, fqdn, dns.TypeAAAA)

		for received := 0; received < queries; received++ {
			res := <-resCh
			if res.err == nil {
				rrSets = append(rrSets, res.resp.Answer)
			} else {
				errs = append(errs, res.err)
			}
		}
	default:
		return []netip.Addr{}, fmt.Errorf("unsupported network %s", network)
	}

	for _, rrs := range rrSets {
		for _, rr := range rrs {
			addr, ok := netip.AddrFromSlice(proxyutil.IPFromRR(rr))
			if !ok {
				continue
			}

			ipAddrs = append(ipAddrs, addr)
		}
	}

	// TODO(e.burkov):  Use [errors.Join] in Go 1.20.
	if len(errs) > 0 && len(ipAddrs) == 0 {
		return []netip.Addr{}, errs[0]
	}

	// Use the previous dnsproxy behavior: prefer IPv4 by default.
	//
	// TODO(a.garipov): Consider unexporting this entire method or
	// documenting that the order of addrs is undefined.
	proxynetutil.SortNetIPAddrs(ipAddrs, false)

	return ipAddrs, nil
}

// resolve sends one recursion-desired query of type qtype and class IN for
// host through the embedded Upstream.  It returns the upstream's response and
// error as is.  host is used verbatim, so callers pass a fully-qualified name.
func (r upstreamResolver) resolve(host string, qtype uint16) (resp *dns.Msg, err error) {
	q := dns.Question{Name: host, Qtype: qtype, Qclass: dns.ClassINET}

	req := &dns.Msg{}
	req.Id = dns.Id()
	req.RecursionDesired = true
	req.Question = []dns.Question{q}

	resp, err = r.Exchange(req)

	return resp, err
}

// resolveResult is the result of a single concurrent lookup, as sent by
// resolveAsync.  It is an alias for an anonymous struct type.
type resolveResult = struct {
	// resp is the response returned by resolve.
	resp *dns.Msg
	// err is the error returned by resolve, if any.
	err  error
}

// resolveAsync calls resolve for host and qtype and sends the response and
// error to resCh as a single *resolveResult.  It is intended to be run as a
// goroutine.
func (r upstreamResolver) resolveAsync(
	resCh chan<- *resolveResult,
	host string,
	qtype uint16,
) {
	res := &resolveResult{}
	res.resp, res.err = r.resolve(host, qtype)

	resCh <- res
}
07070100000087000081A4000000000000000000000001650C59210000092E000000000000000000000000000000000000003200000000dnsproxy-0.55.0/upstream/upstreamresolver_test.gopackage upstream

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewUpstreamResolver(t *testing.T) {
	r, err := NewUpstreamResolver("1.1.1.1:53", &Options{Timeout: 3 * time.Second})
	require.NoError(t, err)

	ipAddrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com")
	require.NoError(t, err)

	assert.NotEmpty(t, ipAddrs)
}

// TestNewUpstreamResolver_validity checks that NewUpstreamResolver accepts
// addresses that need no bootstrap and rejects those with hostnames, using
// the expected error messages.
func TestNewUpstreamResolver_validity(t *testing.T) {
	withTimeoutOpt := &Options{Timeout: 3 * time.Second}

	testCases := []struct {
		name       string
		addr       string
		wantErrMsg string
	}{{
		name:       "udp",
		addr:       "1.1.1.1:53",
		wantErrMsg: "",
	}, {
		name:       "dot",
		addr:       "tls://1.1.1.1",
		wantErrMsg: "",
	}, {
		name:       "doh",
		addr:       "https://1.1.1.1/dns-query",
		wantErrMsg: "",
	}, {
		name:       "sdns",
		addr:       "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
		wantErrMsg: "",
	}, {
		name:       "tcp",
		addr:       "tcp://9.9.9.9",
		wantErrMsg: "",
	}, {
		name: "invalid_tls",
		addr: "tls://dns.adguard.com",
		wantErrMsg: `bootstrap tls://dns.adguard.com:853: ` +
			`ParseAddr("dns.adguard.com"): unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_https",
		addr: "https://dns.adguard.com/dns-query",
		wantErrMsg: `bootstrap https://dns.adguard.com:443/dns-query: ` +
			`ParseAddr("dns.adguard.com"): unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_tcp",
		addr: "tcp://dns.adguard.com",
		wantErrMsg: `bootstrap tcp://dns.adguard.com:53: ` +
			`ParseAddr("dns.adguard.com"): unexpected character (at "dns.adguard.com")`,
	}, {
		name: "invalid_no_scheme",
		addr: "dns.adguard.com",
		wantErrMsg: `bootstrap dns.adguard.com:53: ParseAddr("dns.adguard.com"): ` +
			`unexpected character (at "dns.adguard.com")`,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			r, err := NewUpstreamResolver(tc.addr, withTimeoutOpt)
			if tc.wantErrMsg != "" {
				// require.EqualError fails the test cleanly on an unexpected
				// nil error, while calling err.Error() directly would panic.
				require.EqualError(t, err, tc.wantErrMsg)

				return
			}

			require.NoError(t, err)

			addrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com")
			require.NoError(t, err)

			assert.NotEmpty(t, addrs)
		})
	}
}
07070100000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000B00000000TRAILER!!!1097 blocks
openSUSE Build Service is sponsored by