diff --git a/LICENSE.md b/LICENSE.md index e071cb42..caf5ae03 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,8 +1,8 @@ # Dual license info The code in this repository is available under a dual licensing model: -1. Open Source License: The code, except for the contents of the "src-tauri/src/enterprise/" directory, is licensed under the AGPL license (this license). This applies to the open core components of the software. -2. Enterprise License: All code in this repository (including within the "src-tauri/src/enterprise/" directory) is licensed under a separate Enterprise License (see file src/enterprise/LICENSE.md). +1. Open Source License: The code, except for the contents of the "src-tauri/enterprise/" directory, is licensed under the AGPL license (this license). This applies to the open core components of the software. +2. Enterprise License: All code in this repository (including within the "src-tauri/enterprise/" directory) is licensed under a separate Enterprise License (see file src-tauri/enterprise/LICENSE.md). # GNU AFFERO GENERAL PUBLIC LICENSE diff --git a/flake.lock b/flake.lock index 37bef692..8992d0e1 100644 --- a/flake.lock +++ b/flake.lock @@ -12,6 +12,21 @@ }, "parent": [] }, + "crane": { + "locked": { + "lastModified": 1779130139, + "narHash": "sha256-BLrtr42azquO7MdGFU5a7KiMl3YpFlTeIXqy1fT5GlQ=", + "owner": "ipetkov", + "repo": "crane", + "rev": "edb38893982a3338972bb4a2ec7ce7c29ba10fd9", + "type": "github" + }, + "original": { + "owner": "ipetkov", + "repo": "crane", + "type": "github" + } + }, "defguard-ui": { "flake": false, "locked": { @@ -44,11 +59,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1778794387, - "narHash": "sha256-BL04pOS9453Awkeb9f90XBJXBSkWxN+vB7HIgnL0iMM=", + "lastModified": 1779351318, + "narHash": "sha256-f+JACbTqzZ+G92DSnXOUGRhGANb8Blh7CoeYOeBF8/U=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "8a1b0127302ea51e05bf4ea5a291743fac442406", + "rev": "4a29d733e8a7d5b824c3d8c958a946a9867b3eb2", "type": "github" }, "original": { @@ -87,6 +102,7 @@ "root": { "inputs": { "boringtun": "boringtun", + "crane": "crane", "defguard-ui": "defguard-ui", "flake-utils": "flake-utils", "nixpkgs": "nixpkgs", @@ -99,11 +115,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1778815121, - "narHash": "sha256-xlhD+1NVJbhrUUM2usRHW6iKWTXP2uw2Fo6sWJmLg8g=", + "lastModified": 1779419951, + "narHash": "sha256-dMX0PUslUHPajP6o8FEoRdFv9afq/dec4POR0vVfjK4=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "017351829a9356423afd2cca0dde9b63346c8ab3", + "rev": "5b5c521d6cae9ef4aa32f888eb2c0ce595c9be52", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 2621b4c0..8da37223 100644 --- a/flake.nix +++ b/flake.nix @@ -3,6 +3,7 @@ nixpkgs.url = "nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay.url = "github:oxalica/rust-overlay"; + crane.url = "github:ipetkov/crane"; # let git manage submodules self.submodules = true; @@ -25,26 +26,55 @@ nixpkgs, flake-utils, rust-overlay, + crane, ... }: flake-utils.lib.eachDefaultSystem (system: let - # add rust overlay - pkgs = import nixpkgs { + # Plain nixpkgs — used for packages and checks. + pkgs = import nixpkgs {inherit system;}; + + # nixpkgs with rust-overlay — only needed for the dev shell, which uses + # pkgs.rust-bin to get a customised Rust toolchain. + devPkgs = import nixpkgs { inherit system; overlays = [rust-overlay.overlays.default]; }; + + craneLib = crane.mkLib pkgs; + + defguard-client = pkgs.callPackage ./nix/package.nix { + inherit pkgs craneLib; + }; in { devShells.default = import ./nix/shell.nix { - inherit pkgs; + pkgs = devPkgs; + inherit crane; }; - packages.default = pkgs.callPackage ./nix/package.nix { - inherit pkgs; + packages = { + default = defguard-client; + inherit defguard-client; + defguard-service = + pkgs.runCommand "defguard-service" { + nativeBuildInputs = [pkgs.makeWrapper]; + } '' + mkdir -p $out/bin + cp ${defguard-client}/bin/defguard-service $out/bin/ + ''; + dg = + pkgs.runCommand "dg" { + nativeBuildInputs = [pkgs.makeWrapper]; + } '' + mkdir -p $out/bin + cp ${defguard-client}/bin/dg $out/bin/ + ''; }; + checks.default = defguard-client; + formatter = pkgs.alejandra; }) // { - nixosModules.default = import ./nix/nixos-module.nix; + nixosModules.default = import ./nix/nixos-module.nix {mkCraneLib = crane.mkLib;}; }; } diff --git a/nix/nixos-module.nix b/nix/nixos-module.nix index f3833d09..2205d510 100644 --- a/nix/nixos-module.nix +++ b/nix/nixos-module.nix @@ -1,61 +1,66 @@ -{ +{mkCraneLib}: { config, lib, pkgs, ... -}: -with lib; let - defguard-client = pkgs.callPackage ./package.nix {}; +}: let + craneLib = mkCraneLib pkgs; + defguard-client = pkgs.callPackage ./package.nix {inherit pkgs craneLib;}; cfg = config.programs.defguard-client; in { options.programs.defguard-client = { - enable = mkEnableOption "Defguard VPN client and service"; + enable = lib.mkEnableOption "Defguard VPN client and service"; - package = mkOption { - type = types.package; + package = lib.mkOption { + type = lib.types.package; default = defguard-client; description = "defguard-client package to use"; }; - logLevel = mkOption { - type = types.str; + logLevel = lib.mkOption { + type = lib.types.str; default = "info"; description = "Log level for defguard-service"; }; - statsPeriod = mkOption { - type = types.int; + statsPeriod = lib.mkOption { + type = lib.types.int; default = 30; description = "Interval in seconds for interface statistics updates"; }; }; - config = mkIf cfg.enable { - # Add client package + config = lib.mkIf cfg.enable { environment.systemPackages = [cfg.package]; - # Setup systemd service for the intrerface management daemon systemd.services.defguard-service = { description = "Defguard VPN Service"; + documentation = ["https://docs.defguard.net"]; wantedBy = ["multi-user.target"]; wants = ["network-online.target"]; after = ["network-online.target"]; serviceConfig = { - ExecStart = "${cfg.package}/bin/defguard-service --log-level ${cfg.logLevel} --stats-period ${toString cfg.statsPeriod}"; - ExecReload = "/bin/kill -HUP $MAINPID"; Group = "defguard"; - Restart = "on-failure"; - RestartSec = 2; + ExecStart = "${cfg.package}/bin/defguard-service --log-level ${cfg.logLevel} --stats-period ${toString cfg.statsPeriod}"; + ExecReload = "kill -HUP $MAINPID"; KillMode = "process"; KillSignal = "SIGINT"; LimitNOFILE = 65536; LimitNPROC = "infinity"; + Restart = "on-failure"; + RestartSec = 2; TasksMax = "infinity"; OOMScoreAdjust = -1000; + # Security hardening + NoNewPrivileges = true; + PrivateTmp = true; + ProtectControlGroups = true; + ProtectKernelModules = true; + RestrictRealtime = true; + LockPersonality = true; }; }; - # Make sure the defguard group exists users.groups.defguard = {}; }; } diff --git a/nix/package.nix b/nix/package.nix index d6dd9a69..ba297d4d 100644 --- a/nix/package.nix +++ b/nix/package.nix @@ -1,8 +1,7 @@ { pkgs, lib, - stdenv, - rustPlatform, + craneLib, rustc, cargo, makeDesktopItem, @@ -10,7 +9,6 @@ fetchPnpmDeps, }: let pname = "defguard-client"; - # Automatically read version from Cargo.toml version = (fromTOML (builtins.readFile ../src-tauri/Cargo.toml)).workspace.package.version; desktopItem = makeDesktopItem { @@ -43,60 +41,113 @@ libayatana-indicator ayatana-ido libdbusmenu-gtk3 - desktop-file-utils - iproute2 - lsb-release - openresolv ]; - nativeBuildInputs = [ + # Rust/cargo inputs shared by buildDepsOnly and the main build. + cargoNativeBuildInputs = [ rustc cargo pkgs.pkg-config pkgs.gobject-introspection pkgs.cargo-tauri - pkgs.nodejs_24 pkgs.protobuf - pnpm - # configures pnpm to use pre-fetched dependencies - pnpmConfigHook - # configures cargo to use pre-fetched dependencies - rustPlatform.cargoSetupHook - # helper to add runtime binary & library deps paths - pkgs.makeWrapper - pkgs.wrapGAppsHook3 ]; -in - stdenv.mkDerivation (finalAttrs: rec { - inherit pname version buildInputs nativeBuildInputs; + # Source filter for buildDepsOnly: Cargo files plus extras needed by build.rs + # (proto files, tauri configs, capabilities, sqlx offline cache). + depsSourceFilter = path: type: + (craneLib.filterCargoSources path type) + || (lib.hasSuffix ".proto" path) + || (lib.hasSuffix "tauri.conf.json" path) + || (lib.hasInfix "/capabilities/" path) + || (lib.hasInfix "/.sqlx/" path) + || (lib.hasSuffix ".sql" path); + + depsSrc = lib.cleanSourceWith { + src = craneLib.path ../src-tauri; + filter = depsSourceFilter; + }; + + cargoVendorDir = craneLib.vendorCargoDeps { + src = craneLib.path ../src-tauri; + }; + + # Pre-compile cargo dependencies; cached as long as Cargo.lock is unchanged. + # Features must match the main build. + cargoArtifacts = craneLib.buildDepsOnly { + inherit pname; + inherit version buildInputs cargoVendorDir; + src = depsSrc; + nativeBuildInputs = cargoNativeBuildInputs; + cargoExtraArgs = "--features custom-protocol"; + VERGEN_IDEMPOTENT = "true"; + SQLX_OFFLINE = "true"; + }; + + # Prefetch pnpm dependencies. + # Explicit pnpm_10 keeps fetchPnpmDeps and pnpmConfigHook on the same major version. + pnpmDeps = fetchPnpmDeps { + inherit pname version pnpm; src = ../.; + fetcherVersion = 3; + hash = "sha256-tqXLmnPgq3G79/RoZjFcXIbDMZB9M49Pz7Lm5rmMs14="; + }; +in + craneLib.mkCargoDerivation { + inherit pname version buildInputs cargoArtifacts cargoVendorDir pnpmDeps; - # prefetch cargo dependencies - cargoRoot = "src-tauri"; - buildAndTestSubdir = "src-tauri"; + src = ../.; - cargoDeps = rustPlatform.importCargoLock { - lockFile = ../src-tauri/Cargo.lock; - }; + nativeBuildInputs = + cargoNativeBuildInputs + ++ [ + pkgs.makeWrapper + pkgs.wrapGAppsHook3 + pkgs.nodejs_24 + pnpm + pnpmConfigHook + ]; + + # Pin CARGO_TARGET_DIR before crane's inheritCargoArtifacts hook runs so + # extraction and tauri's cargo invocation both land in src-tauri/target. + postUnpack = '' + export CARGO_TARGET_DIR="$NIX_BUILD_TOP/$sourceRoot/src-tauri/target" + ''; - # prefetch pnpm dependencies - pnpmDeps = fetchPnpmDeps { - inherit - (finalAttrs) - pname - version - src - ; - - fetcherVersion = 2; - hash = "sha256-vDLgpFaO+48s+tj1/2m2fgNJpCfnNkFJpQkC4Xah59E="; - }; + # Required by mkCargoDerivation even when buildPhase is fully overridden. + buildPhaseCargoCommand = ""; + + preBuild = '' + # Workspace-member build scripts were compiled in buildDepsOnly's source + # tree (/build/source/) with that path baked in; remove them so cargo + # recompiles them against the current tree. Dep .rlib/.rmeta are kept. + rm -rf src-tauri/target/release/build/defguard* + rm -rf src-tauri/target/release/build/common* + rm -rf src-tauri/target/release/.fingerprint/defguard* + rm -rf src-tauri/target/release/.fingerprint/common* + + # tauri_build::build() reads OUT_DIR metadata written by tauri's own + # build script during buildDepsOnly (pointing to /build/source/). + # Remove tauri's build outputs and build-script-run fingerprints so + # cargo re-runs the build script and refreshes OUT_DIR to the current + # path. libtauri*.rlib in deps/ is untouched. + rm -rf src-tauri/target/release/build/tauri-* + find src-tauri/target/release/.fingerprint \ + -maxdepth 1 -type d \( -name 'tauri-*' -o -name 'tauri_*' \) \ + -exec rm -f '{}/build-script-run' \; + ''; buildPhase = '' runHook preBuild - pnpm tauri build --verbose + # Build the frontend first; tauri's beforeBuildCommand is suppressed + # below to avoid running pnpm build a second time. + pnpm build + + # --config replaces the build section from tauri.linux.conf.json. + pnpm tauri build \ + --config '{"build":{"beforeBuildCommand":""}}' \ + --bundles deb runHook postBuild ''; @@ -104,55 +155,45 @@ in installPhase = '' runHook preInstall - mkdir -p $out/bin - - # copy client binary - install -Dm755 src-tauri/target/release/${pname} $out/bin/${pname} + # tauri always writes to src-tauri/target regardless of $CARGO_TARGET_DIR. + local targetDir="src-tauri/target/release" - # copy background service binary - install -Dm755 src-tauri/target/release/defguard-service $out/bin/defguard-service - - # copy CLI binary - install -Dm755 src-tauri/target/release/dg $out/bin/dg + mkdir -p $out/bin + install -Dm755 "$targetDir/${pname}" $out/bin/${pname} + install -Dm755 "$targetDir/defguard-service" $out/bin/defguard-service + install -Dm755 "$targetDir/dg" $out/bin/dg - # Copy resources directory (for tray icons, etc.) mkdir -p $out/lib/${pname} cp -r src-tauri/resources $out/lib/${pname}/ - # install desktop entry mkdir -p $out/share/applications cp ${desktopItem}/share/applications/* $out/share/applications/ - # install icon files mkdir -p $out/share/icons/hicolor/{32x32,128x128}/apps - install -Dm644 src-tauri/icons/32x32.png $out/share/icons/hicolor/32x32/apps/${pname}.png + install -Dm644 src-tauri/icons/32x32.png $out/share/icons/hicolor/32x32/apps/${pname}.png install -Dm644 src-tauri/icons/128x128.png $out/share/icons/hicolor/128x128/apps/${pname}.png runHook postInstall ''; - # add extra args to wrapGAppsHook3 wrapper preFixup = '' gappsWrapperArgs+=( - --prefix PATH : ${ - lib.makeBinPath [ - # `defguard-service` needs `ip` to manage WireGuard - pkgs.iproute2 - # `defguard-service` needs `resolvconf` to manage DNS - pkgs.openresolv - # `defguard-client` needs `update-desktop-database` and `lsb_release` - pkgs.desktop-file-utils - pkgs.lsb-release - ] - } - --prefix LD_LIBRARY_PATH : ${ - lib.makeLibraryPath [ - pkgs.libayatana-appindicator - ] - } + --prefix PATH : ${lib.makeBinPath [pkgs.iproute2 pkgs.desktop-file-utils pkgs.lsb-release]} + --suffix PATH : ${lib.makeBinPath [pkgs.openresolv]} + --prefix LD_LIBRARY_PATH : ${lib.makeLibraryPath [pkgs.libayatana-appindicator]} ) ''; + VERGEN_IDEMPOTENT = "true"; + SQLX_OFFLINE = "true"; + doInstallCargoArtifacts = false; + + # passthru attrs are ignored by the build but addressable by external tools: + # pnpmDeps — referenced by the update-pnpm-hash.yaml CI workflow + passthru = { + inherit pnpmDeps; + }; + meta = with lib; { description = "Defguard VPN Client"; homepage = "https://defguard.net"; @@ -160,4 +201,4 @@ in maintainers = with maintainers; [wojcik91]; platforms = platforms.linux; }; - }) + } diff --git a/nix/shell.nix b/nix/shell.nix index bd885377..3ed2e761 100644 --- a/nix/shell.nix +++ b/nix/shell.nix @@ -1,18 +1,17 @@ -{pkgs ? import {}}: let +{ + pkgs, + crane, +}: let # add development-related cargo tooling rustToolchain = pkgs.rust-bin.stable.latest.default.override { extensions = ["rust-analyzer" "rust-src" "rustfmt" "clippy"]; targets = ["x86_64-apple-darwin" "aarch64-apple-darwin" "x86_64-pc-windows-gnu"]; }; - # share custom toolchain with package - rustPlatform = pkgs.makeRustPlatform { - cargo = rustToolchain; - rustc = rustToolchain; - }; + craneLib = crane.mkLib pkgs; defguard-client = pkgs.callPackage ./package.nix { - inherit rustPlatform; + inherit craneLib; cargo = rustToolchain; rustc = rustToolchain; }; @@ -33,7 +32,6 @@ in sqlx-cli vtsls trivy - just ]; shellHook = with pkgs; '' diff --git a/src-tauri/.sqlx/query-157a0344c45c1b0567e8815c3f7b3667b753fb0b4b4c10d9d497b8f403ccb89d.json b/src-tauri/.sqlx/query-157a0344c45c1b0567e8815c3f7b3667b753fb0b4b4c10d9d497b8f403ccb89d.json new file mode 100644 index 00000000..67d032ff --- /dev/null +++ b/src-tauri/.sqlx/query-157a0344c45c1b0567e8815c3f7b3667b753fb0b4b4c10d9d497b8f403ccb89d.json @@ -0,0 +1,74 @@ +{ + "db_name": "SQLite", + "query": "SELECT id \"id: _\", name, uuid, url, proxy_url, username, token \"token?\", client_traffic_policy, enterprise_enabled, openid_display_name FROM instance WHERE name = $1;", + "describe": { + "columns": [ + { + "name": "id: _", + "ordinal": 0, + "type_info": "Integer" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "uuid", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "url", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "proxy_url", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "username", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "token?", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "client_traffic_policy", + "ordinal": 7, + "type_info": "Integer" + }, + { + "name": "enterprise_enabled", + "ordinal": 8, + "type_info": "Bool" + }, + { + "name": "openid_display_name", + "ordinal": 9, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + false, + false, + true + ] + }, + "hash": "157a0344c45c1b0567e8815c3f7b3667b753fb0b4b4c10d9d497b8f403ccb89d" +} diff --git a/src-tauri/.sqlx/query-3bedd8a0e3a8d4b76330ba0f81d82cf1590e6d15ba30360c41b0a5a3482df3df.json b/src-tauri/.sqlx/query-3bedd8a0e3a8d4b76330ba0f81d82cf1590e6d15ba30360c41b0a5a3482df3df.json deleted file mode 100644 index 9f7beee7..00000000 --- a/src-tauri/.sqlx/query-3bedd8a0e3a8d4b76330ba0f81d82cf1590e6d15ba30360c41b0a5a3482df3df.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "SELECT EXISTS (SELECT 1 FROM location WHERE service_location_mode <= $1)", - "describe": { - "columns": [ - { - "name": "EXISTS (SELECT 1 FROM location WHERE service_location_mode <= $1)", - "ordinal": 0, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false - ] - }, - "hash": "3bedd8a0e3a8d4b76330ba0f81d82cf1590e6d15ba30360c41b0a5a3482df3df" -} diff --git a/src-tauri/.sqlx/query-c6a5e793cccc520039e28da8b4fb73e0c79c6a8d671c300ec2ea3eb0d58342b5.json b/src-tauri/.sqlx/query-c6a5e793cccc520039e28da8b4fb73e0c79c6a8d671c300ec2ea3eb0d58342b5.json deleted file mode 100644 index 2436bff9..00000000 --- a/src-tauri/.sqlx/query-c6a5e793cccc520039e28da8b4fb73e0c79c6a8d671c300ec2ea3eb0d58342b5.json +++ /dev/null @@ -1,104 +0,0 @@ -{ - "db_name": "SQLite", - "query": "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\", mfa_method \"mfa_method: _\", posture_check_required FROM location WHERE service_location_mode <= $1 ORDER BY name ASC", - "describe": { - "columns": [ - { - "name": "id", - "ordinal": 0, - "type_info": "Integer" - }, - { - "name": "instance_id", - "ordinal": 1, - "type_info": "Integer" - }, - { - "name": "name", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "address", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "pubkey", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "endpoint", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "allowed_ips", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "dns", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "network_id", - "ordinal": 8, - "type_info": "Integer" - }, - { - "name": "route_all_traffic", - "ordinal": 9, - "type_info": "Bool" - }, - { - "name": "keepalive_interval", - "ordinal": 10, - "type_info": "Integer" - }, - { - "name": "location_mfa_mode: LocationMfaMode", - "ordinal": 11, - "type_info": "Integer" - }, - { - "name": "service_location_mode: ServiceLocationMode", - "ordinal": 12, - "type_info": "Integer" - }, - { - "name": "mfa_method: _", - "ordinal": 13, - "type_info": "Integer" - }, - { - "name": "posture_check_required", - "ordinal": 14, - "type_info": "Bool" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - true, - false, - false, - false, - false, - false, - true, - false - ] - }, - "hash": "c6a5e793cccc520039e28da8b4fb73e0c79c6a8d671c300ec2ea3eb0d58342b5" -} diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index ad6706a2..36f3d982 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1044,13 +1044,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "common" -version = "2.1.0" -dependencies = [ - "nix", -] - [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1415,15 +1408,21 @@ dependencies = [ "block2 0.6.2", "chrono", "clap", - "common", "dark-light", + "defguard-client-common", + "defguard-client-config-sync", + "defguard-client-core", + "defguard-client-posture", + "defguard-client-proto", + "defguard-client-provisioning", + "defguard-client-service-locations", "defguard_wireguard_rs", "dirs-next", "futures-core", "hyper-util", "known-folders", "log", - "nix", + "nix 0.31.3", "objc2 0.6.4", "objc2-app-kit", "objc2-foundation 0.3.2", @@ -1462,7 +1461,6 @@ dependencies = [ "tokio-util", "tonic", "tonic-prost", - "tonic-prost-build", "tower", "tracing", "tracing-appender", @@ -1477,12 +1475,159 @@ dependencies = [ "x25519-dalek", ] +[[package]] +name = "defguard-client-common" +version = "2.1.0" +dependencies = [ + "nix 0.31.3", +] + +[[package]] +name = "defguard-client-config-sync" +version = "2.1.0" +dependencies = [ + "defguard-client-core", + "defguard-client-proto", + "defguard-client-service-locations", + "log", + "reqwest 0.13.4", + "semver", + "serde", + "serde_json", + "sqlx", + "tokio", + "tonic", +] + +[[package]] +name = "defguard-client-core" +version = "2.1.0" +dependencies = [ + "base64 0.22.1", + "block2 0.6.2", + "chrono", + "defguard-client-common", + "defguard-client-proto", + "defguard_wireguard_rs", + "dirs-next", + "hyper-util", + "log", + "nix 0.31.3", + "objc2 0.6.4", + "objc2-app-kit", + "objc2-foundation 0.3.2", + "objc2-network-extension", + "os_info", + "prost", + "reqwest 0.13.4", + "rust-ini", + "semver", + "serde", + "serde_json", + "serde_with", + "sqlx", + "struct-patch", + "strum", + "thiserror 2.0.18", + "tokio", + "tonic", + "tower", + "windows-sys 0.61.2", + "x25519-dalek", +] + +[[package]] +name = "defguard-client-posture" +version = "2.1.0" +dependencies = [ + "defguard-client-core", + "defguard-client-proto", + "log", + "reqwest 0.13.4", + "serde", + "serde_json", + "sysinfo", +] + +[[package]] +name = "defguard-client-proto" +version = "2.1.0" +dependencies = [ + "defguard_wireguard_rs", + "prost", + "serde", + "serde_with", + "tonic", + "tonic-prost", + "tonic-prost-build", + "x25519-dalek", +] + +[[package]] +name = "defguard-client-provisioning" +version = "2.1.0" +dependencies = [ + "defguard-client-core", + "log", + "serde", + "serde_json", +] + +[[package]] +name = "defguard-client-service" +version = "2.1.0" +dependencies = [ + "anyhow", + "async-stream", + "clap", + "defguard-client-common", + "defguard-client-posture", + "defguard-client-proto", + "defguard-client-service-locations", + "defguard_wireguard_rs", + "futures-core", + "log", + "nix 0.30.1", + "serde", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tonic", + "tracing", + "tracing-appender", + "tracing-subscriber", + "windows-core 0.61.2", + "windows-service", + "windows-sys 0.61.2", +] + +[[package]] +name = "defguard-client-service-locations" +version = "2.1.0" +dependencies = [ + "base64 0.22.1", + "defguard-client-common", + "defguard-client-core", + "defguard-client-proto", + "defguard_wireguard_rs", + "known-folders", + "log", + "prost", + "serde", + "serde_json", + "thiserror 2.0.18", + "windows 0.62.2", + "windows-acl", + "windows-service", + "windows-sys 0.61.2", +] + [[package]] name = "defguard-dg" version = "2.1.0" dependencies = [ "clap", - "common", + "defguard-client-common", "defguard_wireguard_rs", "dirs-next", "prost", @@ -1513,7 +1658,7 @@ dependencies = [ "ip_network", "ip_network_table", "libc", - "nix", + "nix 0.31.3", "parking_lot", "ring", "socket2", @@ -1541,7 +1686,7 @@ dependencies = [ "netlink-packet-utils", "netlink-packet-wireguard", "netlink-sys", - "nix", + "nix 0.31.3", "regex", "serde", "thiserror 2.0.18", @@ -3720,6 +3865,18 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "nix" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +dependencies = [ + "bitflags 2.11.1", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nix" version = "0.31.3" @@ -4265,7 +4422,7 @@ checksum = "9cf20a545b305cf1da722b236b5155c9bb35f1d5ceb28c048bd96ca842f41b5b" dependencies = [ "android_system_properties", "log", - "nix", + "nix 0.31.3", "objc2 0.6.4", "objc2-foundation 0.3.2", "objc2-ui-kit", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 4ab353d8..a0b46ce9 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -1,15 +1,20 @@ [workspace] -members = ["cli", "common"] -default-members = [".", "cli"] +members = ["cli", "common", "client-proto", "core", "daemon", "enterprise/posture", "enterprise/provisioning", "enterprise/config-sync", "enterprise/service-locations"] +default-members = [".", "cli", "daemon"] [workspace.dependencies] +base64 = "0.22" clap = { version = "4.5", features = ["cargo", "derive", "env"] } defguard_wireguard_rs = "0.9" dirs-next = "2.0" +log = { version = "0.4", features = ["serde"] } prost = "0.14" reqwest = { version = "0.13", features = ["cookies", "json"] } +semver = { version = "1.0", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +serde_with = "3.11" +sqlx = { version = "0.8", features = ["chrono", "runtime-tokio", "sqlite", "uuid", "macros"] } thiserror = "2.0" tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] } tonic = { version = "0.14", default-features = false, features = [ @@ -49,42 +54,37 @@ version.workspace = true [[bin]] name = "defguard-client" -[[bin]] -name = "defguard-service" -required-features = ["service"] - [build-dependencies] tauri-build = { version = "2", features = [] } -tonic-prost-build.workspace = true vergen-git2 = { version = "9.1", features = ["build"] } [dependencies] anyhow = "1.0" -base64 = "0.22" +base64.workspace = true clap.workspace = true chrono = { version = "0.4", features = ["serde"] } -common = { path = "common" } +defguard-client-proto = { path = "client-proto" } +defguard-client-core = { path = "core" } +defguard-client-posture = { path = "enterprise/posture" } +defguard-client-config-sync = { path = "enterprise/config-sync" } +defguard-client-service-locations = { path = "enterprise/service-locations" } +defguard-client-provisioning = { path = "enterprise/provisioning" } +defguard-client-common = { path = "common" } dark-light = "2.0" defguard_wireguard_rs = { workspace = true, features = ["check_dependencies"] } dirs-next.workspace = true hyper-util = "0.1" -log = { version = "0.4", features = ["serde"] } +log.workspace = true os_info = { version = "3.14", default-features = false } prost.workspace = true regex = "1.12" reqwest.workspace = true rust-ini = "0.21" -semver = "1.0" +semver.workspace = true serde.workspace = true serde_json.workspace = true -serde_with = "3.11" -sqlx = { version = "0.8", features = [ - "chrono", - "sqlite", - "runtime-tokio", - "uuid", - "macros", -] } +serde_with.workspace = true +sqlx.workspace = true struct-patch = "0.12" strum = { version = "0.28", features = ["derive"] } sysinfo = { version = "0.39", default-features = false, features = ["apple-app-store", "system"] } @@ -173,7 +173,6 @@ wmi = {version = "0.18", default-features = false} # If you use cargo directly instead of tauri's cli you can use this feature flag to switch between tauri's `dev` and `build` modes. # DO NOT REMOVE!! custom-protocol = ["tauri/custom-protocol"] -service = [] [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/src-tauri/build.rs b/src-tauri/build.rs index 0bbd4977..7ab8755b 100644 --- a/src-tauri/build.rs +++ b/src-tauri/build.rs @@ -13,29 +13,6 @@ fn main() -> Result<(), Box> { let git2 = Git2Builder::default().branch(true).sha(true).build()?; Emitter::default().add_instructions(&git2)?.emit()?; - tonic_prost_build::configure() - // Enable optional fields. - .protoc_arg("--experimental_allow_proto3_optional") - // Make sure empty DNS is deserialized correctly as `None`. - .type_attribute(".DeviceConfig", "#[serde_as]") - .field_attribute( - ".DeviceConfig.dns", - "#[serde_as(deserialize_as = \"NoneAsEmptyString\")]", - ) - // Make all messages serde-serializable. - .type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]") - .compile_protos( - &[ - "proto/v1/client/client.proto", - "proto/v1/core/proxy.proto", - "proto/enterprise/v2/posture/posture.proto", - "proto/common/client_types.proto", - ], - &["proto"], - )?; - tauri_build::build(); - - println!("cargo:rerun-if-changed=proto"); Ok(()) } diff --git a/src-tauri/cli/Cargo.toml b/src-tauri/cli/Cargo.toml index f38c6cf9..c9b511d0 100644 --- a/src-tauri/cli/Cargo.toml +++ b/src-tauri/cli/Cargo.toml @@ -12,7 +12,7 @@ tonic-prost-build.workspace = true [dependencies] clap.workspace = true -common = { path = "../common" } +common = { package = "defguard-client-common", path = "../common" } defguard_wireguard_rs = { workspace = true, features = ["check_dependencies"] } dirs-next.workspace = true prost.workspace = true diff --git a/src-tauri/client-proto/Cargo.toml b/src-tauri/client-proto/Cargo.toml new file mode 100644 index 00000000..5cd0528d --- /dev/null +++ b/src-tauri/client-proto/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "defguard-client-proto" +description = "Protobuf definitions for the Defguard desktop client" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license-file = "../../LICENSE.md" +rust-version.workspace = true +version.workspace = true + +[build-dependencies] +tonic-prost-build.workspace = true + +[dependencies] +prost.workspace = true +serde.workspace = true +serde_with = "3.11" +tonic.workspace = true +tonic-prost.workspace = true + +defguard_wireguard_rs.workspace = true + +[dev-dependencies] +x25519-dalek = { version = "2", features = ["getrandom", "static_secrets"] } diff --git a/src-tauri/client-proto/build.rs b/src-tauri/client-proto/build.rs new file mode 100644 index 00000000..95f41e21 --- /dev/null +++ b/src-tauri/client-proto/build.rs @@ -0,0 +1,26 @@ +fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../proto"); + + tonic_prost_build::configure() + // Enable optional fields. + .protoc_arg("--experimental_allow_proto3_optional") + // Make sure empty DNS is deserialized correctly as `None`. + .type_attribute(".DeviceConfig", "#[serde_as]") + .field_attribute( + ".DeviceConfig.dns", + "#[serde_as(deserialize_as = \"NoneAsEmptyString\")]", + ) + // Make all messages serde-serializable. + .type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]") + .compile_protos( + &[ + "../proto/v1/client/client.proto", + "../proto/v1/core/proxy.proto", + "../proto/enterprise/v2/posture/posture.proto", + "../proto/common/client_types.proto", + ], + &["../proto"], + )?; + + Ok(()) +} diff --git a/src-tauri/client-proto/src/conversions.rs b/src-tauri/client-proto/src/conversions.rs new file mode 100644 index 00000000..a0ac7ec4 --- /dev/null +++ b/src-tauri/client-proto/src/conversions.rs @@ -0,0 +1,141 @@ +use std::{ + str::FromStr, + time::{Duration, UNIX_EPOCH}, +}; + +use defguard_wireguard_rs::{ + host::Host, key::Key, net::IpAddrMask, peer::Peer, InterfaceConfiguration, +}; + +use crate::defguard::client::v1::{InterfaceConfig, InterfaceData, Peer as ProtoPeer}; + +impl From for InterfaceConfig { + fn from(config: InterfaceConfiguration) -> Self { + Self { + name: config.name, + prvkey: config.prvkey, + address: config + .addresses + .iter() + .map(ToString::to_string) + .collect::>() + .join(","), + port: u32::from(config.port), + peers: config.peers.into_iter().map(Into::into).collect(), + mtu: config.mtu, + } + } +} + +impl From for InterfaceConfiguration { + fn from(config: InterfaceConfig) -> Self { + let addresses = config + .address + .split(',') + .filter_map(|ip| IpAddrMask::from_str(ip.trim()).ok()) + .collect(); + Self { + name: config.name, + prvkey: config.prvkey, + addresses, + port: config.port as u16, + peers: config.peers.into_iter().map(Into::into).collect(), + mtu: config.mtu, + fwmark: None, // TODO: add to config + } + } +} + +impl From for ProtoPeer { + fn from(peer: Peer) -> Self { + Self { + public_key: peer.public_key.to_lower_hex(), + preshared_key: peer.preshared_key.map(|key| key.to_lower_hex()), + protocol_version: peer.protocol_version, + endpoint: peer.endpoint.map(|addr| addr.to_string()), + last_handshake: peer.last_handshake.map(|time| { + time.duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() + }), + tx_bytes: peer.tx_bytes, + rx_bytes: peer.rx_bytes, + persistent_keepalive_interval: peer.persistent_keepalive_interval.map(u32::from), + allowed_ips: peer + .allowed_ips + .into_iter() + .map(|addr| addr.to_string()) + .collect(), + } + } +} + +impl From for Peer { + fn from(peer: ProtoPeer) -> Self { + Self { + public_key: Key::decode(peer.public_key).expect("Failed to parse public key"), + preshared_key: peer + .preshared_key + .map(|key| Key::decode(key).expect("Failed to parse preshared key: {key}")), + protocol_version: peer.protocol_version, + endpoint: peer.endpoint.map(|addr| { + addr.parse() + .expect("Failed to parse endpoint address: {addr}") + }), + last_handshake: peer + .last_handshake + .map(|timestamp| UNIX_EPOCH + Duration::from_secs(timestamp)), + tx_bytes: peer.tx_bytes, + rx_bytes: peer.rx_bytes, + persistent_keepalive_interval: peer + .persistent_keepalive_interval + .and_then(|interval| u16::try_from(interval).ok()), + allowed_ips: peer + .allowed_ips + .into_iter() + .map(|addr| addr.parse().expect("Failed to parse allowed IP: {addr}")) + .collect(), + } + } +} + +impl From for InterfaceData { + fn from(host: Host) -> Self { + Self { + listen_port: u32::from(host.listen_port), + peers: host.peers.into_values().map(Into::into).collect(), + } + } +} + +#[cfg(test)] +mod tests { + use std::time::SystemTime; + + use defguard_wireguard_rs::{key::Key, net::IpAddrMask, peer::Peer}; + use x25519_dalek::{EphemeralSecret, PublicKey}; + + use super::*; + + #[test] + fn convert_peer() { + let secret = EphemeralSecret::random(); + let key = PublicKey::from(&secret); + let peer_key: Key = key.as_ref().try_into().unwrap(); + let mut base_peer = Peer::new(peer_key); + let addr = IpAddrMask::from_str("10.20.30.2/32").unwrap(); + base_peer.allowed_ips.push(addr); + // Workaround since nanoseconds are lost in conversion. + base_peer.last_handshake = Some(SystemTime::UNIX_EPOCH); + base_peer.protocol_version = Some(3); + base_peer.endpoint = Some("127.0.0.1:8080".parse().unwrap()); + base_peer.tx_bytes = 100; + base_peer.rx_bytes = 200; + + let proto_peer: ProtoPeer = base_peer.clone().into(); + + let converted_peer: Peer = proto_peer.into(); + + assert_eq!(base_peer, converted_peer); + } +} diff --git a/src-tauri/client-proto/src/lib.rs b/src-tauri/client-proto/src/lib.rs new file mode 100644 index 00000000..36eb7547 --- /dev/null +++ b/src-tauri/client-proto/src/lib.rs @@ -0,0 +1,28 @@ +pub mod conversions; +pub mod posture_ext; + +pub mod defguard { + pub mod client_types { + tonic::include_proto!("defguard.client_types"); + } + + pub mod client { + pub mod v1 { + tonic::include_proto!("defguard.client.v1"); + } + } + + pub mod proxy { + pub mod v1 { + tonic::include_proto!("defguard.proxy.v1"); + } + } + + pub mod enterprise { + pub mod posture { + pub mod v2 { + tonic::include_proto!("defguard.enterprise.posture.v2"); + } + } + } +} diff --git a/src-tauri/client-proto/src/posture_ext.rs b/src-tauri/client-proto/src/posture_ext.rs new file mode 100644 index 00000000..2e6632a4 --- /dev/null +++ b/src-tauri/client-proto/src/posture_ext.rs @@ -0,0 +1,52 @@ +use std::fmt; + +use crate::defguard::enterprise::posture::v2::{ + bool_check, int32_check, string_check, BoolCheck, Int32Check, StringCheck, UnavailableReason, +}; + +impl fmt::Display for UnavailableReason { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Unspecified => f.write_str("unspecified"), + Self::DetectionFailed => f.write_str("detection failed"), + Self::NotApplicable => f.write_str("not applicable on this platform"), + Self::InsufficientPermissions => f.write_str("insufficient permissions"), + } + } +} + +/// Convert `Result` to `BoolCheck`. +impl From> for BoolCheck { + fn from(value: Result) -> Self { + Self { + result: Some(match value { + Ok(inner) => bool_check::Result::Value(inner), + Err(err) => bool_check::Result::Unavailable(err as i32), + }), + } + } +} + +/// Convert `Result` to `Int32Check`. +impl From> for Int32Check { + fn from(value: Result) -> Self { + Self { + result: Some(match value { + Ok(inner) => int32_check::Result::Value(inner), + Err(err) => int32_check::Result::Unavailable(err as i32), + }), + } + } +} + +/// Convert `Result` to `StringCheck`. +impl From> for StringCheck { + fn from(value: Result) -> Self { + Self { + result: Some(match value { + Ok(inner) => string_check::Result::Value(inner), + Err(err) => string_check::Result::Unavailable(err as i32), + }), + } + } +} diff --git a/src-tauri/common/Cargo.toml b/src-tauri/common/Cargo.toml index 43c7fd19..800990ef 100644 --- a/src-tauri/common/Cargo.toml +++ b/src-tauri/common/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "common" +name = "defguard-client-common" authors.workspace = true edition.workspace = true homepage.workspace = true diff --git a/src-tauri/common/src/lib.rs b/src-tauri/common/src/lib.rs index 190fcbd3..65ebabda 100644 --- a/src-tauri/common/src/lib.rs +++ b/src-tauri/common/src/lib.rs @@ -1,5 +1,8 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener}; +/// Package version from the workspace (shared across all binaries). +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); + /// Obtain a free TCP port on localhost. #[must_use] pub fn find_free_tcp_port() -> Option { diff --git a/src-tauri/core/Cargo.toml b/src-tauri/core/Cargo.toml new file mode 100644 index 00000000..044f2d64 --- /dev/null +++ b/src-tauri/core/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "defguard-client-core" +description = "Shared business logic for the Defguard desktop client (Tauri-free)" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license-file = "../../LICENSE.md" +rust-version.workspace = true +version.workspace = true + +[dependencies] +base64.workspace = true +chrono = { version = "0.4", features = ["serde"] } +defguard-client-common = { path = "../common" } +defguard-client-proto = { path = "../client-proto" } +defguard_wireguard_rs.workspace = true +dirs-next.workspace = true +log.workspace = true +os_info = { version = "3.14", default-features = false } +prost.workspace = true +reqwest.workspace = true +rust-ini = "0.21" +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +serde_with.workspace = true +sqlx.workspace = true +strum = { version = "0.28", features = ["derive"] } +struct-patch = "0.12" +thiserror.workspace = true +tokio.workspace = true +tonic.workspace = true +tower = "0.5" +x25519-dalek = { version = "2", features = ["getrandom", "serde", "static_secrets"] } + +hyper-util = "0.1" + +[target.'cfg(unix)'.dependencies] +nix = { version = "0.31", features = ["user", "fs"] } + +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.61", features = ["Win32_Foundation"] } + +[target.'cfg(target_os = "macos")'.dependencies] +block2 = "0.6" +objc2 = "0.6" +objc2-app-kit = "0.3" +objc2-foundation = "0.3" +objc2-network-extension = "0.3" diff --git a/src-tauri/src/app_config.rs b/src-tauri/core/src/app_config.rs similarity index 81% rename from src-tauri/src/app_config.rs rename to src-tauri/core/src/app_config.rs index 8c067573..5836b5e0 100644 --- a/src-tauri/src/app_config.rs +++ b/src-tauri/core/src/app_config.rs @@ -1,23 +1,19 @@ use std::{ fs::{create_dir_all, File, OpenOptions}, - path::PathBuf, + path::Path, }; use log::LevelFilter; use serde::{Deserialize, Serialize}; use struct_patch::Patch; -use tauri::{AppHandle, Manager}; #[cfg(unix)] use crate::set_perms; static APP_CONFIG_FILE_NAME: &str = "config.json"; -fn get_config_file_path(app: &AppHandle) -> PathBuf { - let mut config_file_path = app - .path() - .app_data_dir() - .expect("Failed to access app data"); +fn get_config_file_path(config_dir: &Path) -> std::path::PathBuf { + let mut config_file_path = config_dir.to_path_buf(); if !config_file_path.exists() { create_dir_all(&config_file_path).expect("Failed to create missing app data dir"); } @@ -29,8 +25,8 @@ fn get_config_file_path(app: &AppHandle) -> PathBuf { config_file_path } -fn get_config_file(app: &AppHandle, for_write: bool) -> File { - let config_file_path = get_config_file_path(app); +fn get_config_file(config_dir: &Path, for_write: bool) -> File { + let config_file_path = get_config_file_path(config_dir); OpenOptions::new() .create(true) .read(true) @@ -76,20 +72,20 @@ impl Default for AppConfig { } impl AppConfig { - /// Try to load application configuration from application data directory. + /// Try to load application configuration from the given config directory. /// If reading the configuration file fails, default settings will be returned. #[must_use] - pub fn new(app: &AppHandle) -> Self { - let config_path = get_config_file_path(app); + pub fn new(config_dir: &Path) -> Self { + let config_path = get_config_file_path(config_dir); if !config_path.exists() { eprintln!( "Application configuration file doesn't exist; initializing it with the defaults." ); let res = Self::default(); - res.save(app); + res.save(config_dir); return res; } - let config_file = get_config_file(app, false); + let config_file = get_config_file(config_dir, false); let mut app_config = Self::default(); match serde_json::from_reader::<_, AppConfigPatch>(config_file) { Ok(patch) => { @@ -100,16 +96,16 @@ impl AppConfig { eprintln!( "Failed to deserialize application configuration file: {err}. Using defaults." ); - app_config.save(app); + app_config.save(config_dir); } } app_config } - /// Saves currently loaded AppConfig into app data dir file. + /// Saves currently loaded AppConfig into the given config directory file. /// Warning: this will always overwrite file contents. - pub fn save(&self, app: &AppHandle) { - let file = get_config_file(app, true); + pub fn save(&self, config_dir: &Path) { + let file = get_config_file(config_dir, true); match serde_json::to_writer(file, &self) { Ok(()) => debug!("Application configuration file has been saved."), Err(err) => { diff --git a/src-tauri/core/src/connection/active_connections.rs b/src-tauri/core/src/connection/active_connections.rs new file mode 100644 index 00000000..0feaa895 --- /dev/null +++ b/src-tauri/core/src/connection/active_connections.rs @@ -0,0 +1,89 @@ +use std::{collections::HashSet, sync::LazyLock}; + +use tokio::sync::Mutex; + +use crate::{ + connection::disconnect_interface, + database::{ + models::{connection::ActiveConnection, instance::Instance, location::Location, Id}, + DB_POOL, + }, + error::Error, + ConnectionType, +}; + +pub static ACTIVE_CONNECTIONS: LazyLock>> = + LazyLock::new(|| Mutex::new(Vec::new())); + +pub async fn get_connection_id_by_type(connection_type: ConnectionType) -> Vec { + let active_connections = ACTIVE_CONNECTIONS.lock().await; + + active_connections + .iter() + .filter_map(|con| { + if con.connection_type == connection_type { + Some(con.location_id) + } else { + None + } + }) + .collect() +} + +pub async fn close_all_connections() -> Result<(), Error> { + debug!("Closing all active connections"); + let active_connections = ACTIVE_CONNECTIONS.lock().await; + let active_connections_count = active_connections.len(); + debug!("Found {active_connections_count} active connections"); + for connection in active_connections.iter() { + debug!( + "Found active connection with location {}", + connection.location_id + ); + trace!("Connection: {connection:#?}"); + debug!("Removing interface {}", connection.interface_name); + disconnect_interface(connection).await?; + } + if active_connections_count > 0 { + info!("All active connections ({active_connections_count}) have been closed."); + } else { + debug!("There were no active connections to close, nothing to do."); + } + Ok(()) +} + +pub async fn find_connection(id: Id, connection_type: ConnectionType) -> Option { + let connections = ACTIVE_CONNECTIONS.lock().await; + trace!( + "Checking for active connection with ID {id}, type {connection_type} in active connections." + ); + + if let Some(connection) = connections + .iter() + .find(|conn| conn.location_id == id && conn.connection_type == connection_type) + { + trace!("Found connection: {connection:?}"); + Some(connection.to_owned()) + } else { + debug!( + "Couldn't find connection with ID {id}, type: {connection_type} in active connections." + ); + None + } +} + +/// Returns active connections for a given instance. +pub async fn active_connections(instance: &Instance) -> Result, Error> { + let locations: HashSet = Location::find_by_instance_id(&*DB_POOL, instance.id, false) + .await? + .iter() + .map(|location| location.id) + .collect(); + Ok(ACTIVE_CONNECTIONS + .lock() + .await + .iter() + .filter(|connection| locations.contains(&connection.location_id)) + .cloned() + .collect()) +} diff --git a/src-tauri/core/src/connection/apple.rs b/src-tauri/core/src/connection/apple.rs new file mode 100644 index 00000000..373c9607 --- /dev/null +++ b/src-tauri/core/src/connection/apple.rs @@ -0,0 +1,967 @@ +//! Interchangeability and communication with VPNExtension (written in Swift). + +use std::{ + collections::HashMap, + hint::spin_loop, + net::IpAddr, + ptr::NonNull, + str::FromStr, + sync::{ + atomic::{AtomicBool, Ordering}, + mpsc::{self, channel, Receiver, RecvTimeoutError, Sender}, + Arc, LazyLock, Mutex, + }, + time::Duration, +}; + +use block2::RcBlock; +use defguard_client_common::dns_owned; +use defguard_wireguard_rs::{key::Key, net::IpAddrMask, peer::Peer}; +use objc2::{ + rc::Retained, + runtime::{AnyObject, ProtocolObject}, +}; +use objc2_foundation::{ + ns_string, NSArray, NSData, NSDate, NSDictionary, NSError, NSMutableArray, NSMutableDictionary, + NSNotification, NSNotificationCenter, NSNumber, NSObjectProtocol, NSOperationQueue, NSRunLoop, + NSString, +}; +use objc2_network_extension::{ + NETunnelProviderManager, NETunnelProviderProtocol, NETunnelProviderSession, NEVPNConnection, + NEVPNStatus, NEVPNStatusDidChangeNotification, +}; +use serde::Deserialize; +use tracing::Level; + +use crate::{ + database::{ + models::{ + instance::{ClientTrafficPolicy, Instance}, + location::Location, + tunnel::Tunnel, + wireguard_keys::WireguardKeys, + Id, + }, + DB_POOL, + }, + error::Error, + ConnectionType, DEFAULT_ROUTE_IPV4, DEFAULT_ROUTE_IPV6, +}; + +const PLUGIN_BUNDLE_ID: &str = "net.defguard.VPNExtension"; +const SYSTEM_SYNC_DELAY: Duration = Duration::from_millis(500); +const LOCATION_ID: &str = "locationId"; +const TUNNEL_ID: &str = "tunnelId"; + +type ObserverSender = Mutex>; +type ObserverReceiver = Mutex>>; + +static OBSERVER_COMMS: LazyLock<(ObserverSender, ObserverReceiver)> = LazyLock::new(|| { + let (tx, rx) = mpsc::channel(); + (Mutex::new(tx), Mutex::new(Some(rx))) +}); + +type VpnStateSender = Mutex>; +type VpnStateReceiver = Mutex>>; + +static VPN_STATE_UPDATE_COMMS: LazyLock<(VpnStateSender, VpnStateReceiver)> = LazyLock::new(|| { + let (tx, rx) = mpsc::channel(); + (Mutex::new(tx), Mutex::new(Some(rx))) +}); + +/// Thread responsible for observing VPN status changes. +/// This is intentionally a blocking function, as it uses the Objective-C objects which are not +/// thread safe. +pub fn observer_thread( + initial_managers: HashMap<(&'static str, Id), Retained>, +) { + debug!("Starting VPN connection observer thread"); + let receiver = { + let mut rx_opt = OBSERVER_COMMS + .1 + .lock() + .expect("Failed to lock observer receiver"); + rx_opt.take().expect("Receiver already taken") + }; + + let mut observers = HashMap::new(); + + // spawn initial observers for existing managers + for ((key, value), manager) in initial_managers { + debug!("Spawning initial observer for manager with key: {key}, value: {value}"); + let connection = unsafe { manager.connection() }; + let observer = create_observer(&connection); + debug!("Registered initial observer for manager with key: {key}, value: {value}"); + observers.insert((key, value), observer); + } + + loop { + match receiver.recv_timeout(OBSERVER_CLEANUP_INTERVAL) { + Ok(message) => { + debug!("Received message to observe the following connection: {message:?}"); + + let (key, value) = message; + + if observers.contains_key(&(key, value)) { + debug!( + "Observer for manager with key: {key}, value: {value} already exists, + skipping", + ); + continue; + } + + let manager = manager_for_key_and_value(key, value).unwrap(); + let connection = unsafe { manager.connection() }; + let observer = create_observer(&connection); + + observers.insert((key, value), observer); + debug!("Registered observer for manager with key: {key}, value: {value}"); + } + Err(RecvTimeoutError::Timeout) => { + debug!("Performing periodic cleanup of dead observers"); + let mut dead_keys = Vec::new(); + + for (key, value) in observers.keys() { + if manager_for_key_and_value(key, *value).is_none() { + debug!( + "Manager for key: {key}, value: {value} no longer exists, marking for + removal" + ); + dead_keys.push((*key, *value)); + } + } + + for dead_key in dead_keys { + if let Some(_observer) = observers.remove(&dead_key) { + debug!( + "Removed dead VPN connection observer for key: {}, value: {}", + dead_key.0, dead_key.1 + ); + } + } + } + Err(RecvTimeoutError::Disconnected) => { + error!("Observer receiver channel disconnected, exiting observer thread"); + break; + } + } + } + + debug!("Exiting VPN connection observer thread"); +} + +/// Tunnel statistics shared with VPNExtension (written in Swift). +#[derive(Deserialize)] +#[repr(C)] +#[serde(rename_all = "camelCase")] +pub(crate) struct Stats { + pub(crate) location_id: Option, + pub(crate) tunnel_id: Option, + pub(crate) tx_bytes: u64, + pub(crate) rx_bytes: u64, + pub(crate) last_handshake: u64, +} + +/// Run [`NSRunLoop`] until semaphore becomes `true`. +pub fn spawn_runloop_and_wait_for(semaphore: &Arc) { + const ONE_SECOND: f64 = 1.; + let run_loop = NSRunLoop::currentRunLoop(); + let mut date = NSDate::dateWithTimeIntervalSinceNow(ONE_SECOND); + loop { + run_loop.runUntilDate(&date); + if semaphore.load(Ordering::Acquire) { + break; + } + date = date.dateByAddingTimeInterval(ONE_SECOND); + } +} + +/// Handle VPN status change. +fn vpn_status_change_handler(notification: &NSNotification) { + let name = notification.name(); + debug!("Received VPN status change notification: {name:?}"); + VPN_STATE_UPDATE_COMMS + .0 + .lock() + .expect("Failed to lock state update sender") + .send(()) + .expect("Failed to send to state update channel"); + debug!("Sent status update request to channel"); +} + +/// Observe VPN status change. +fn create_observer(object: &NEVPNConnection) -> Retained> { + let center = NSNotificationCenter::defaultCenter(); + let block = RcBlock::new(move |notification: NonNull| { + vpn_status_change_handler(unsafe { notification.as_ref() }); + }); + let queue = NSOperationQueue::mainQueue(); + unsafe { + let name = NEVPNStatusDidChangeNotification; + center.addObserverForName_object_queue_usingBlock( + Some(name), + Some(object), + Some(&queue), + &block, + ) + } +} + +#[must_use] +pub fn get_managers_for_tunnels_and_locations( + tunnels: &[Tunnel], + locations: &[Location], +) -> HashMap<(&'static str, Id), Retained> { + let mut managers = HashMap::new(); + + for location in locations { + if let Some(manager) = manager_for_key_and_value(LOCATION_ID, location.id) { + managers.insert((LOCATION_ID, location.id), manager); + } + } + + for tunnel in tunnels { + if let Some(manager) = manager_for_key_and_value(TUNNEL_ID, tunnel.id) { + managers.insert((TUNNEL_ID, tunnel.id), manager); + } + } + + managers +} + +/// Try to get `Id` out of manager. ID is embedded in configuration dictionary under `key`. +fn id_from_manager(manager: &NETunnelProviderManager, key: &NSString) -> Option { + let plugin_bundle_id = ns_string!(PLUGIN_BUNDLE_ID); + + let vpn_protocol = (unsafe { manager.protocolConfiguration() })?; + let Ok(tunnel_protocol) = vpn_protocol.downcast::() else { + error!("Failed to downcast to NETunnelProviderProtocol"); + return None; + }; + // Sometimes all managers from all apps come through, so filter by bundle ID. + if let Some(bundle_id) = unsafe { tunnel_protocol.providerBundleIdentifier() } { + if &*bundle_id != plugin_bundle_id { + return None; + } + } + + if let Some(config_dict) = unsafe { tunnel_protocol.providerConfiguration() } { + if let Some(any_object) = config_dict.objectForKey(key) { + let Ok(id) = any_object.downcast::() else { + warn!("Failed to downcast ID to NSNumber"); + return None; + }; + return Some(id.as_i64()); + } + } + + None +} + +/// Try to find [`NETunnelProviderManager`] in system settings that matches key and value. +/// Key is usually `locationId` or `tunnelId`. +fn manager_for_key_and_value(key: &str, value: Id) -> Option> { + let key_string = NSString::from_str(key); + let (tx, rx) = channel(); + + let handler = RcBlock::new( + move |managers_ptr: *mut NSArray, error_ptr: *mut NSError| { + if !error_ptr.is_null() { + error!("Failed to load tunnel provider managers."); + return; + } + + let Some(managers) = (unsafe { managers_ptr.as_ref() }) else { + error!("No managers"); + return; + }; + + for manager in managers { + if let Some(id) = id_from_manager(&manager, &key_string) { + if id == value { + // This is the manager we were looking for. + tx.send(Some(manager)).expect("Sender is dead"); + return; + } + } + } + + tx.send(None).expect("Sender is dead"); + }, + ); + unsafe { + NETunnelProviderManager::loadAllFromPreferencesWithCompletionHandler(&handler); + } + + rx.recv().expect("Receiver is dead") +} + +/// Tunnel configuration shared with VPNExtension (written in Swift). +pub(crate) struct TunnelConfiguration { + location_id: Option, + tunnel_id: Option, + name: String, + private_key: String, + addresses: Vec, + listen_port: Option, + peers: Vec, + mtu: Option, + dns: Vec, + dns_search: Vec, +} + +impl TunnelConfiguration { + /// Convert to [`NSDictionary`]. + fn as_nsdict(&self) -> Retained> { + let dict = NSMutableDictionary::new(); + + if let Some(location_id) = self.location_id { + dict.insert( + ns_string!(LOCATION_ID), + NSNumber::new_i64(location_id).as_ref(), + ); + } + + if let Some(tunnel_id) = self.tunnel_id { + dict.insert(ns_string!(TUNNEL_ID), NSNumber::new_i64(tunnel_id).as_ref()); + } + + dict.insert(ns_string!("name"), NSString::from_str(&self.name).as_ref()); + + dict.insert( + ns_string!("privateKey"), + NSString::from_str(&self.private_key).as_ref(), + ); + + // IpAddrMask + let addresses = NSMutableArray::>::new(); + for addr in &self.addresses { + let addr_dict = NSMutableDictionary::::new(); + addr_dict.insert( + ns_string!("address"), + NSString::from_str(&addr.address.to_string()).as_ref(), + ); + addr_dict.insert(ns_string!("cidr"), NSNumber::new_u8(addr.cidr).as_ref()); + addresses.addObject(addr_dict.into_super().as_ref()); + } + dict.insert(ns_string!("addresses"), addresses.as_ref()); + + if let Some(listen_port) = self.listen_port { + dict.insert( + ns_string!("listenPort"), + NSNumber::new_u16(listen_port).as_ref(), + ); + } + + // Peer + let peers = NSMutableArray::>::new(); + for peer in &self.peers { + let peer_dict = NSMutableDictionary::::new(); + peer_dict.insert( + ns_string!("publicKey"), + NSString::from_str(&peer.public_key.to_string()).as_ref(), + ); + + if let Some(preshared_key) = &peer.preshared_key { + peer_dict.insert( + ns_string!("preSharedKey"), + NSString::from_str(&preshared_key.to_string()).as_ref(), + ); + } + + if let Some(endpoint) = &peer.endpoint { + peer_dict.insert( + ns_string!("endpoint"), + NSString::from_str(&endpoint.to_string()).as_ref(), + ); + } + + // Skipping: lastHandshake, txBytes, rxBytes. + + if let Some(persistent_keep_alive) = peer.persistent_keepalive_interval { + peer_dict.insert( + ns_string!("persistentKeepAlive"), + NSNumber::new_u16(persistent_keep_alive).as_ref(), + ); + } + + // IpAddrMask + let allowed_ips = NSMutableArray::>::new(); + for addr in &peer.allowed_ips { + let addr_dict = NSMutableDictionary::::new(); + addr_dict.insert( + ns_string!("address"), + NSString::from_str(&addr.address.to_string()).as_ref(), + ); + addr_dict.insert(ns_string!("cidr"), NSNumber::new_u8(addr.cidr).as_ref()); + allowed_ips.addObject(addr_dict.into_super().as_ref()); + } + peer_dict.insert(ns_string!("allowedIPs"), allowed_ips.as_ref()); + + peers.addObject(peer_dict.into_super().as_ref()); + } + dict.insert(ns_string!("peers"), peers.into_super().as_ref()); + + if let Some(mtu) = self.mtu { + dict.insert(ns_string!("mtu"), NSNumber::new_u32(mtu).as_ref()); + } + + let dns = NSMutableArray::::new(); + for entry in &self.dns { + dns.addObject(NSString::from_str(&entry.to_string()).as_ref()); + } + dict.insert(ns_string!("dns"), dns.as_ref()); + + let dns_search = NSMutableArray::::new(); + for entry in &self.dns_search { + dns_search.addObject(NSString::from_str(entry).as_ref()); + } + dict.insert(ns_string!("dnsSearch"), dns_search.as_ref()); + + dict.into_super() + } + + /// Try to find `NETunnelProviderManager` for this configuration, based on location ID or + /// tunnel ID. + pub(crate) fn tunnel_provider_manager(&self) -> Option> { + let (key, value) = match (self.location_id, self.tunnel_id) { + (Some(location_id), None) => (LOCATION_ID, location_id), + (None, Some(tunnel_id)) => (TUNNEL_ID, tunnel_id), + _ => return None, + }; + + manager_for_key_and_value(key, value) + } + + /// Create or update system VPN settings with this configuration. + pub(crate) fn save(&self) { + let spinlock = Arc::new(AtomicBool::new(false)); + let spinlock_clone = Arc::clone(&spinlock); + let plugin_bundle_id = ns_string!(PLUGIN_BUNDLE_ID); + + let provider_manager = self + .tunnel_provider_manager() + .unwrap_or_else(|| unsafe { NETunnelProviderManager::new() }); + + unsafe { + let tunnel_protocol = NETunnelProviderProtocol::new(); + tunnel_protocol.setProviderBundleIdentifier(Some(plugin_bundle_id)); + let server_address = self.peers.first().map_or(String::new(), |peer| { + peer.endpoint.map_or(String::new(), |sa| sa.to_string()) + }); + let server_address = NSString::from_str(&server_address); + // `serverAddress` must have a non-nil string value for the protocol configuration to be + // valid. + tunnel_protocol.setServerAddress(Some(&server_address)); + + let provider_config = self.as_nsdict(); + tunnel_protocol.setProviderConfiguration(Some(&*provider_config)); + + provider_manager.setProtocolConfiguration(Some(&tunnel_protocol)); + let name = NSString::from_str(&self.name); + provider_manager.setLocalizedDescription(Some(&name)); + provider_manager.setEnabled(true); + + // Save to system settings. + let handler = RcBlock::new(move |error_ptr: *mut NSError| { + if error_ptr.is_null() { + debug!("Saved tunnel configuration for {name} to system settings"); + } else { + error!("Failed to save tunnel configuration for: {name} to system settings"); + } + spinlock_clone.store(true, Ordering::Release); + }); + provider_manager.saveToPreferencesWithCompletionHandler(Some(&*handler)); + } + + while !spinlock.load(Ordering::Acquire) { + spin_loop(); + } + } + + /// Start tunnel for this configuration. + pub(crate) fn start_tunnel(&self) { + if let Some(provider_manager) = self.tunnel_provider_manager() { + if let Err(err) = + unsafe { provider_manager.connection().startVPNTunnelAndReturnError() } + { + error!("Failed to start VPN: {err}"); + } else { + OBSERVER_COMMS + .0 + .lock() + .expect("Failed to lock observer sender") + .send(( + self.location_id + .map_or_else(|| TUNNEL_ID, |_location_id| LOCATION_ID), + self.location_id.or(self.tunnel_id).unwrap(), + )) + .expect("Failed to send to observer channel"); + info!("VPN started"); + } + } else { + debug!( + "Couldn't find configuration from system settings for {}", + self.name + ); + } + } +} + +/// Retrieve VPN tunnel statistics from VPNExtension. +pub(crate) fn tunnel_stats(id: Id, connection_type: &ConnectionType) -> Option { + let new_stats = Arc::new(Mutex::new(None)); + let plugin_bundle_id = ns_string!(PLUGIN_BUNDLE_ID); + + let new_stats_clone = Arc::clone(&new_stats); + + let finished = Arc::new(AtomicBool::new(false)); + let finished_clone = Arc::clone(&finished); + + let response_handler = RcBlock::new(move |data_ptr: *mut NSData| { + if let Some(data) = unsafe { data_ptr.as_ref() } { + if let Ok(stats) = serde_json::from_slice(data.to_vec().as_slice()) { + if let Ok(mut new_stats_locked) = new_stats_clone.lock() { + *new_stats_locked = Some(stats); + } + } else { + warn!("Failed to deserialize tunnel stats"); + } + } else { + debug!("No data received in tunnel stats response, skipping"); + } + finished_clone.store(true, Ordering::Release); + }); + + let manager = manager_for_key_and_value( + match connection_type { + ConnectionType::Location => LOCATION_ID, + ConnectionType::Tunnel => TUNNEL_ID, + }, + id, + )?; + + let vpn_protocol = (unsafe { manager.protocolConfiguration() })?; + let Ok(tunnel_protocol) = vpn_protocol.downcast::() else { + error!("Failed to downcast to NETunnelProviderProtocol"); + return None; + }; + + // Sometimes all managers from all apps come through, so filter by bundle ID. + if let Some(bundle_id) = unsafe { tunnel_protocol.providerBundleIdentifier() } { + if &*bundle_id != plugin_bundle_id { + return None; + } + } + + let Ok(session) = unsafe { manager.connection() }.downcast::() else { + error!("Failed to downcast to NETunnelProviderSession"); + return None; + }; + + let message_data = NSData::new(); + if unsafe { + session.sendProviderMessage_returnError_responseHandler( + &message_data, + None, + Some(&response_handler), + ) + } { + debug!("Message sent to NETunnelProviderSession"); + } else { + error!("Failed to send to NETunnelProviderSession while requesting stats"); + } + + // Wait for all handlers to complete. + while !finished.load(Ordering::Acquire) { + spin_loop(); + } + + let stats = new_stats + .lock() + .map_or(None, |mut new_stats_locked| new_stats_locked.take()); + + stats +} + +/// Synchronize locations and tunnels with system settings. +pub async fn sync_locations_and_tunnels(mtu: Option) -> Result<(), sqlx::Error> { + // Update location settings. + let all_locations = Location::all(&*DB_POOL, false).await?; + for location in &all_locations { + // For syncing, set `preshred_key` to `None`. + let Ok(tunnel_config) = location.tunnel_configurarion(None, mtu).await else { + error!( + "Failed to convert location {} to tunnel configuration.", + location.name + ); + continue; + }; + tunnel_config.save(); + } + + // Update tunnel settings. + let all_tunnels = Tunnel::all(&*DB_POOL).await?; + for tunnel in &all_tunnels { + let Ok(tunnel_config) = tunnel.tunnel_configurarion(mtu) else { + error!( + "Failed to convert tunnel {} to tunnel configuration.", + tunnel.name + ); + continue; + }; + tunnel_config.save(); + } + + debug!("Saved all configurations with system settings."); + + // Convert to Vec. + let mut all_location_ids = all_locations + .into_iter() + .map(|entry| entry.id) + .collect::>(); + let mut all_tunnel_ids = all_tunnels + .into_iter() + .map(|entry| entry.id) + .collect::>(); + // For faster lookup using binary search (see below). + all_location_ids.sort_unstable(); + all_tunnel_ids.sort_unstable(); + + let spinlock = Arc::new(AtomicBool::new(false)); + let spinlock_clone = Arc::clone(&spinlock); + let handler = RcBlock::new( + move |managers_ptr: *mut NSArray, error_ptr: *mut NSError| { + if !error_ptr.is_null() { + error!("Failed to load tunnel provider managers."); + return; + } + + let Some(managers) = (unsafe { managers_ptr.as_ref() }) else { + error!("No managers"); + return; + }; + + let location_key = NSString::from_str(LOCATION_ID); + let tunnel_key = NSString::from_str(TUNNEL_ID); + for manager in managers { + if let Some(id) = id_from_manager(&manager, &location_key) { + if all_location_ids.binary_search(&id).is_ok() { + // Known location - skip. + continue; + } + } + if let Some(id) = id_from_manager(&manager, &tunnel_key) { + if all_tunnel_ids.binary_search(&id).is_ok() { + // Known tunnel - skip. + continue; + } + } + unsafe { manager.removeFromPreferencesWithCompletionHandler(None) }; + } + + spinlock_clone.store(true, Ordering::Release); + }, + ); + unsafe { + NETunnelProviderManager::loadAllFromPreferencesWithCompletionHandler(&handler); + } + + while !spinlock.load(Ordering::Acquire) { + spin_loop(); + } + + debug!("Removed unknown configurations from system settings."); + + Ok(()) +} + +impl Location { + /// Build [`TunnelConfiguration`] from [`Location`]. + pub(crate) async fn tunnel_configurarion( + &self, + preshared_key: Option, + mtu: Option, + ) -> Result { + debug!("Looking for WireGuard keys for location {self} instance"); + let Some(keys) = WireguardKeys::find_by_instance_id(&*DB_POOL, self.instance_id).await? + else { + error!("No keys found for instance: {}", self.instance_id); + return Err(Error::InternalError( + "No keys found for instance".to_string(), + )); + }; + debug!("WireGuard keys found for location {self} instance"); + + // prepare peer config + debug!("Decoding location {self} public key: {}.", self.pubkey); + let peer_key = Key::from_str(&self.pubkey)?; + debug!("Location {self} public key decoded: {peer_key}"); + let mut peer = Peer::new(peer_key); + + debug!("Parsing location {self} endpoint: {}", self.endpoint); + peer.set_endpoint(&self.endpoint)?; + peer.persistent_keepalive_interval = Some(25); + debug!("Parsed location {self} endpoint: {}", self.endpoint); + + if let Some(psk) = preshared_key { + debug!("Decoding location {self} preshared key."); + let peer_psk = Key::from_str(&psk)?; + info!("Location {self} preshared key decoded."); + peer.preshared_key = Some(peer_psk); + } + + debug!("Parsing location {self} allowed IPs: {}", self.allowed_ips); + let Some(instance) = Instance::find_by_id(&*DB_POOL, self.instance_id).await? else { + error!("Instance {} not found", self.instance_id); + return Err(Error::InternalError(format!( + "Instance {} not found", + self.instance_id + ))); + }; + let route_all_traffic = match instance.client_traffic_policy { + ClientTrafficPolicy::ForceAllTraffic => true, + ClientTrafficPolicy::DisableAllTraffic => false, + ClientTrafficPolicy::None => self.route_all_traffic, + }; + let allowed_ips = if route_all_traffic { + debug!("Using all traffic routing for location {self}"); + vec![DEFAULT_ROUTE_IPV4.into(), DEFAULT_ROUTE_IPV6.into()] + } else { + debug!( + "Using predefined location {self} traffic: {}", + self.allowed_ips + ); + self.allowed_ips.split(',').map(str::to_string).collect() + }; + for allowed_ip in &allowed_ips { + match IpAddrMask::from_str(allowed_ip) { + Ok(addr) => { + peer.allowed_ips.push(addr); + } + Err(err) => { + // Handle the error from IpAddrMask::from_str, if needed + error!( + "Error parsing IP address {allowed_ip} while setting up interface for \ + location {self}, error details: {err}" + ); + } + } + } + debug!( + "Parsed allowed IPs for location {self}: {:?}", + peer.allowed_ips + ); + + let addresses = self + .address + .split(',') + .map(str::trim) + .map(IpAddrMask::from_str) + .collect::>() + .map_err(|err| { + let msg = format!("Failed to parse IP addresses '{}': {err}", self.address); + error!("{msg}"); + Error::InternalError(msg) + })?; + let (dns, dns_search) = dns_owned(&self.dns); + Ok(TunnelConfiguration { + location_id: Some(self.id), + tunnel_id: None, + name: self.name.clone(), + private_key: keys.prvkey, + addresses, + listen_port: Some(0), + peers: vec![peer], + mtu, + dns, + dns_search, + }) + } + + /// Check whether VPN tunnel is running for [`Location`]. + pub(crate) fn status(&self) -> Option { + manager_for_key_and_value(LOCATION_ID, self.id).map_or_else( + || { + debug!( + "Couldn't find configuration in system settings for location {}", + self.name + ); + None + }, + |provider_manager| unsafe { + let connection = provider_manager.connection(); + Some(connection.status()) + }, + ) + } + + /// Remove configuration from system settings for [`Location`]. + pub(crate) fn remove_config(&self) { + if let Some(provider_manager) = manager_for_key_and_value(LOCATION_ID, self.id) { + unsafe { + provider_manager.removeFromPreferencesWithCompletionHandler(None); + } + } else { + debug!( + "Couldn't find configuration in system settings for location {}", + self.name + ); + } + } + + /// Stop VPN tunnel for [`Location`]. + pub(crate) fn stop_vpn_tunnel(&self) -> bool { + manager_for_key_and_value(LOCATION_ID, self.id).map_or_else( + || { + debug!( + "Couldn't find configuration in system settings for location {}", + self.name + ); + false + }, + |provider_manager| { + unsafe { + provider_manager.connection().stopVPNTunnel(); + } + info!("VPN stopped"); + true + }, + ) + } +} + +impl Tunnel { + /// Build [`TunnelConfiguration`] from [`Tunnel`]. + pub(crate) fn tunnel_configurarion( + &self, + mtu: Option, + ) -> Result { + // prepare peer config + debug!("Decoding tunnel {self} public key: {}.", self.server_pubkey); + let peer_key = Key::from_str(&self.server_pubkey)?; + debug!("Tunnel {self} public key decoded."); + let mut peer = Peer::new(peer_key); + + debug!("Parsing tunnel {self} endpoint: {}", self.endpoint); + peer.set_endpoint(&self.endpoint)?; + peer.persistent_keepalive_interval = Some( + self.persistent_keep_alive + .try_into() + .expect("Failed to parse persistent keep alive"), + ); + debug!("Parsed tunnel {self} endpoint: {}", self.endpoint); + + if let Some(psk) = &self.preshared_key { + debug!("Decoding tunnel {self} preshared key."); + let peer_psk = Key::from_str(psk)?; + debug!("Preshared key for tunnel {self} decoded."); + peer.preshared_key = Some(peer_psk); + } + + debug!("Parsing tunnel {self} allowed ips: {:?}", self.allowed_ips); + let allowed_ips = if self.route_all_traffic { + debug!("Using all traffic routing for tunnel {self}"); + vec![DEFAULT_ROUTE_IPV4.into(), DEFAULT_ROUTE_IPV6.into()] + } else { + let msg = self.allowed_ips.as_ref().map_or_else( + || "No allowed IP addresses found in tunnel {self} configuration".to_string(), + |ips| format!("Using predefined location traffic for tunnel {self}: {ips}"), + ); + debug!("{msg}"); + self.allowed_ips + .as_ref() + .map(|ips| ips.split(',').map(str::to_string).collect()) + .unwrap_or_default() + }; + for allowed_ip in &allowed_ips { + match IpAddrMask::from_str(allowed_ip.trim()) { + Ok(addr) => { + peer.allowed_ips.push(addr); + } + Err(err) => { + // Handle the error from IpAddrMask::from_str, if needed + error!("Error parsing IP address {allowed_ip}: {err}"); + // Continue to the next iteration of the loop + } + } + } + debug!("Parsed tunnel {self} allowed IPs: {:?}", peer.allowed_ips); + + let addresses = self + .address + .split(',') + .map(str::trim) + .map(IpAddrMask::from_str) + .collect::>() + .map_err(|err| { + let msg = format!("Failed to parse IP addresses '{}': {err}", self.address); + error!("{msg}"); + Error::InternalError(msg) + })?; + let (dns, dns_search) = dns_owned(&self.dns); + Ok(TunnelConfiguration { + location_id: None, + tunnel_id: Some(self.id), + name: self.name.clone(), + private_key: self.prvkey.clone(), + addresses, + listen_port: Some(0), + peers: vec![peer], + mtu, + dns, + dns_search, + }) + } + + /// Check whether VPN tunnel is running for [`Tunnel`]. + pub(crate) fn status(&self) -> Option { + manager_for_key_and_value(TUNNEL_ID, self.id).map_or_else( + || { + debug!( + "Couldn't find configuration in system settings for tunnel {}", + self.name + ); + None + }, + |provider_manager| unsafe { + let connection = provider_manager.connection(); + Some(connection.status()) + }, + ) + } + + /// Remove configuration from system settings for [`Tunnel`]. + pub(crate) fn remove_config(&self) { + if let Some(provider_manager) = manager_for_key_and_value(TUNNEL_ID, self.id) { + unsafe { + provider_manager.removeFromPreferencesWithCompletionHandler(None); + } + } else { + debug!( + "Couldn't find configuration in system settings for tunnel {}", + self.name + ); + } + } + + /// Stop tunnel for [`Tunnel`]. + pub(crate) fn stop_vpn_tunnel(&self) -> bool { + manager_for_key_and_value(TUNNEL_ID, self.id).map_or_else( + || { + debug!( + "Couldn't find configuration in system settings for location {}", + self.name + ); + false + }, + |provider_manager| { + unsafe { + provider_manager.connection().stopVPNTunnel(); + } + info!("VPN stopped"); + true + }, + ) + } +} diff --git a/src-tauri/core/src/connection/daemon_client.rs b/src-tauri/core/src/connection/daemon_client.rs new file mode 100644 index 00000000..21f2f42a --- /dev/null +++ b/src-tauri/core/src/connection/daemon_client.rs @@ -0,0 +1,67 @@ +use std::sync::LazyLock; + +use hyper_util::rt::TokioIo; +#[cfg(windows)] +use tokio::net::windows::named_pipe::ClientOptions; +#[cfg(unix)] +use tokio::net::UnixStream; +use tonic::transport::channel::{Channel, Endpoint}; +#[cfg(unix)] +use tonic::transport::Uri; +use tower::service_fn; +#[cfg(windows)] +use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY; + +use defguard_client_proto::defguard::client::v1::desktop_daemon_service_client::DesktopDaemonServiceClient; + +#[cfg(unix)] +const DAEMON_SOCKET_PATH: &str = "/var/run/defguard.socket"; +#[cfg(windows)] +const PIPE_NAME: &str = r"\\.\pipe\defguard_daemon"; + +pub static DAEMON_CLIENT: LazyLock> = LazyLock::new(|| { + log::debug!("Setting up gRPC client"); + let endpoint = Endpoint::from_static("http://localhost"); + let channel; + #[cfg(unix)] + { + channel = endpoint.connect_with_connector_lazy(service_fn(|_: Uri| async { + let stream = match UnixStream::connect(DAEMON_SOCKET_PATH).await { + Ok(stream) => stream, + Err(err) if err.kind() == std::io::ErrorKind::PermissionDenied => { + log::error!( + "Permission denied for UNIX domain socket; please refer to \ + https://docs.defguard.net/support-1/troubleshooting#\ + unix-socket-permission-errors-when-desktop-client-attempts-to-connect-\ + to-vpn-on-linux-machines" + ); + return Err(err); + } + Err(err) => { + log::error!("Problem connecting to UNIX domain socket: {err}"); + return Err(err); + } + }; + log::info!("Created unix gRPC client"); + Ok::<_, std::io::Error>(TokioIo::new(stream)) + })); + }; + #[cfg(windows)] + { + channel = endpoint.connect_with_connector_lazy(service_fn(|_| async { + let client = loop { + match ClientOptions::new().open(PIPE_NAME) { + Ok(client) => break client, + Err(err) if err.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), + Err(err) => { + log::error!("Problem connecting to named pipe: {err}"); + return Err(err); + } + } + }; + log::info!("Created windows gRPC client"); + Ok::<_, std::io::Error>(TokioIo::new(client)) + })); + } + DesktopDaemonServiceClient::new(channel) +}); diff --git a/src-tauri/core/src/connection/mod.rs b/src-tauri/core/src/connection/mod.rs new file mode 100644 index 00000000..4c71ec49 --- /dev/null +++ b/src-tauri/core/src/connection/mod.rs @@ -0,0 +1,15 @@ +pub mod active_connections; +pub mod daemon_client; +pub mod setup; + +#[cfg(target_os = "macos")] +pub mod apple; + +#[cfg(not(target_os = "macos"))] +pub use setup::{disconnect_interface, execute_command, setup_interface, setup_interface_tunnel}; + +#[cfg(target_os = "macos")] +pub use apple::{ + get_managers_for_tunnels_and_locations, location_tunnel_configuration, + sync_locations_and_tunnels, tunnel_stats, tunnel_tunnel_configuration, TunnelConfiguration, +}; diff --git a/src-tauri/core/src/connection/setup.rs b/src-tauri/core/src/connection/setup.rs new file mode 100644 index 00000000..be3659e4 --- /dev/null +++ b/src-tauri/core/src/connection/setup.rs @@ -0,0 +1,303 @@ +// Non-macOS connection setup helpers. + +use std::str::FromStr; + +use std::process::Command; + +use defguard_client_common::{find_free_tcp_port, get_interface_name}; +use defguard_wireguard_rs::{key::Key, net::IpAddrMask, peer::Peer, InterfaceConfiguration}; +use tonic::Code; + +use crate::{ + connection::daemon_client::DAEMON_CLIENT, + database::{ + models::{connection::ActiveConnection, location::Location, tunnel::Tunnel, Id}, + DbPool, DB_POOL, + }, + error::Error, + DEFAULT_ROUTE_IPV4, DEFAULT_ROUTE_IPV6, +}; +use defguard_client_proto::defguard::client::v1::{CreateInterfaceRequest, RemoveInterfaceRequest}; + +pub async fn setup_interface( + location: &Location, + name: &str, + preshared_key: Option, + mtu: Option, + pool: &DbPool, +) -> Result { + log::debug!("Setting up interface for location: {location}"); + let interface_name = get_interface_name(name); + + log::debug!("Looking for a free port for interface {interface_name}."); + let Some(port) = find_free_tcp_port() else { + let msg = format!( + "Couldn't find free port during interface {interface_name} setup for location {location}" + ); + log::error!("{msg}"); + return Err(Error::InternalError(msg)); + }; + log::debug!("Found free port: {port} for interface {interface_name}."); + + let interface_config = location + .interface_configuration(pool, interface_name.clone(), preshared_key, mtu) + .await?; + log::debug!( + "Creating interface for location {location} with configuration {interface_config:?}" + ); + let request = CreateInterfaceRequest { + config: Some(interface_config.clone().into()), + dns: location.dns.clone(), + }; + if let Err(error) = DAEMON_CLIENT.clone().create_interface(request).await { + if error.code() == Code::Unavailable { + log::error!( + "Failed to set up connection for location {location}; background service is \ + unavailable. Make sure the service is running. Error: {error}" + ); + Err(Error::InternalError( + "Background service is unavailable. Make sure the service is running.".into(), + )) + } else { + log::error!( + "Failed to send a request to the background service to create an interface for \ + location {location}. Error: {error}" + ); + Err(Error::InternalError(format!( + "Failed to send a request to the background service to create an interface for \ + location {location}. Error: {error}. Check logs for details." + ))) + } + } else { + log::info!( + "The interface for location {location} has been created successfully, interface \ + name: {}.", + interface_config.name + ); + Ok(interface_name) + } +} + +pub async fn setup_interface_tunnel( + tunnel: &Tunnel, + name: &str, + mtu: Option, +) -> Result { + log::debug!("Setting up interface for tunnel {tunnel}"); + let interface_name = get_interface_name(name); + + log::debug!( + "Decoding tunnel {tunnel} public key: {}.", + tunnel.server_pubkey + ); + let peer_key = Key::from_str(&tunnel.server_pubkey)?; + log::debug!("Tunnel {tunnel} public key decoded."); + let mut peer = Peer::new(peer_key); + + log::debug!("Parsing tunnel {tunnel} endpoint: {}", tunnel.endpoint); + peer.set_endpoint(&tunnel.endpoint)?; + peer.persistent_keepalive_interval = Some( + tunnel + .persistent_keep_alive + .try_into() + .expect("Failed to parse persistent keep alive"), + ); + log::debug!("Parsed tunnel {tunnel} endpoint: {}", tunnel.endpoint); + + if let Some(psk) = &tunnel.preshared_key { + log::debug!("Decoding tunnel {tunnel} preshared key."); + let peer_psk = Key::from_str(psk)?; + log::debug!("Preshared key for tunnel {tunnel} decoded."); + peer.preshared_key = Some(peer_psk); + } + + log::debug!( + "Parsing tunnel {tunnel} allowed ips: {:?}", + tunnel.allowed_ips + ); + let allowed_ips = if tunnel.route_all_traffic { + log::debug!("Using all traffic routing for tunnel {tunnel}"); + vec![DEFAULT_ROUTE_IPV4.into(), DEFAULT_ROUTE_IPV6.into()] + } else { + let msg = match &tunnel.allowed_ips { + Some(ips) => format!("Using predefined location traffic for tunnel {tunnel}: {ips}"), + None => "No allowed IP addresses found in tunnel {tunnel} configuration".to_string(), + }; + log::debug!("{msg}"); + tunnel + .allowed_ips + .as_ref() + .map(|ips| ips.split(',').map(str::to_string).collect()) + .unwrap_or_default() + }; + for allowed_ip in &allowed_ips { + match IpAddrMask::from_str(allowed_ip.trim()) { + Ok(addr) => { + peer.allowed_ips.push(addr); + } + Err(err) => { + log::error!("Error parsing IP address {allowed_ip}: {err}"); + } + } + } + log::debug!("Parsed tunnel {tunnel} allowed IPs: {:?}", peer.allowed_ips); + + log::debug!("Looking for a free port for interface {interface_name}."); + let Some(port) = find_free_tcp_port() else { + let msg = format!( + "Couldn't find free port for interface {interface_name} while setting up tunnel \ + {tunnel}" + ); + log::error!("{msg}"); + return Err(Error::InternalError(msg)); + }; + log::debug!("Found free port: {port} for interface {interface_name}."); + + let addresses = tunnel + .address + .split(',') + .map(str::trim) + .map(IpAddrMask::from_str) + .collect::>() + .map_err(|err| { + let msg = format!("Failed to parse IP addresses '{}': {err}", tunnel.address); + log::error!("{msg}"); + Error::InternalError(msg) + })?; + let interface_config = InterfaceConfiguration { + name: interface_name.clone(), + prvkey: tunnel.prvkey.clone(), + addresses, + port, + peers: vec![peer.clone()], + mtu, + fwmark: None, + }; + + log::debug!("Creating interface {interface_config:?}"); + let request = CreateInterfaceRequest { + config: Some(interface_config.clone().into()), + dns: tunnel.dns.clone(), + }; + if let Some(pre_up) = &tunnel.pre_up { + log::debug!( + "Executing defined PreUp command before setting up the interface {} for the tunnel \ + {tunnel}: {pre_up}", + interface_config.name + ); + let _ = execute_command(pre_up); + log::info!( + "Executed defined PreUp command before setting up the interface {} for the tunnel \ + {tunnel}: {pre_up}", + interface_config.name + ); + } + if let Err(error) = DAEMON_CLIENT.clone().create_interface(request).await { + log::error!( + "Failed to create a network interface ({}) for tunnel {tunnel}: {error}", + interface_config.name + ); + return Err(Error::InternalError(format!( + "Failed to create a network interface ({}) for tunnel {tunnel}, error message: {}. \ + Check logs for more details.", + interface_config.name, + error.message() + ))); + } + + log::info!( + "Network interface {} for tunnel {tunnel} created successfully.", + interface_config.name + ); + if let Some(post_up) = &tunnel.post_up { + log::debug!( + "Executing defined PostUp command after setting up the interface {} for the tunnel \ + {tunnel}: {post_up}", + interface_config.name + ); + let _ = execute_command(post_up); + log::info!( + "Executed defined PostUp command after setting up the interface {} for the tunnel \ + {tunnel}: {post_up}", + interface_config.name + ); + } + log::debug!( + "Created interface {} with config: {interface_config:?}", + interface_config.name + ); + + Ok(interface_name) +} + +pub async fn disconnect_interface(active_connection: &ActiveConnection) -> Result<(), Error> { + log::debug!( + "Disconnecting interface {}.", + active_connection.interface_name + ); + let location_id = active_connection.location_id; + let interface_name = active_connection.interface_name.clone(); + + let Some(location) = Location::find_by_id(&*DB_POOL, location_id).await? else { + log::error!( + "Error while disconnecting interface {interface_name}, location with ID \ + {location_id} not found" + ); + return Err(Error::NotFound); + }; + + let request = RemoveInterfaceRequest { + interface_name, + endpoint: location.endpoint.clone(), + }; + log::debug!( + "Sending request to the background service to remove interface {} for location {}...", + active_connection.interface_name, + location.name + ); + if let Err(error) = DAEMON_CLIENT.clone().remove_interface(request).await { + let msg = if error.code() == Code::Unavailable { + format!( + "Couldn't remove interface {}. Background service is unavailable. \ + Please make sure the service is running. Error: {error}.", + active_connection.interface_name + ) + } else { + format!( + "Failed to send a request to the background service to remove interface \ + {}. Error: {error}.", + active_connection.interface_name + ) + }; + log::error!("{msg}"); + } + + log::info!( + "Interface {} for location {} disconnected.", + active_connection.interface_name, + location.name + ); + Ok(()) +} + +pub fn execute_command(command: &str) -> Result<(), Error> { + log::debug!("Executing command: {command}"); + let mut command_parts = command.split_whitespace(); + + if let Some(command) = command_parts.next() { + let output = Command::new(command).args(command_parts).output()?; + + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + log::debug!("Command {command} executed successfully. Stdout: {stdout}"); + if !stderr.is_empty() { + log::error!("Command produced the following output on stderr: {stderr}"); + } + } else { + let stderr = String::from_utf8_lossy(&output.stderr); + log::error!("Error while executing command: {command}. Stderr: {stderr}"); + } + } + Ok(()) +} diff --git a/src-tauri/src/database/mod.rs b/src-tauri/core/src/database/mod.rs similarity index 98% rename from src-tauri/src/database/mod.rs rename to src-tauri/core/src/database/mod.rs index 99866016..5cbdcaa2 100644 --- a/src-tauri/src/database/mod.rs +++ b/src-tauri/core/src/database/mod.rs @@ -15,7 +15,7 @@ const DB_NAME: &str = "defguard.db"; pub mod models; -pub(crate) type DbPool = SqlitePool; +pub type DbPool = SqlitePool; pub static DB_POOL: LazyLock = LazyLock::new(|| { let db_url = prepare_db_url().expect("Wrong database URL."); @@ -96,7 +96,7 @@ fn prepare_db_url() -> Result { pub async fn handle_db_migrations() { debug!("Running database migrations, if there are any."); - sqlx::migrate!() + sqlx::migrate!("../migrations") .run(&*DB_POOL) .await .expect("Failed to apply database migrations."); diff --git a/src-tauri/src/database/models/connection.rs b/src-tauri/core/src/database/models/connection.rs similarity index 91% rename from src-tauri/src/database/models/connection.rs rename to src-tauri/core/src/database/models/connection.rs index f816c226..0ba0ab81 100644 --- a/src-tauri/src/database/models/connection.rs +++ b/src-tauri/core/src/database/models/connection.rs @@ -14,7 +14,7 @@ pub struct Connection { } impl Connection { - pub(crate) async fn save<'e, E>(self, executor: E) -> Result, Error> + pub async fn save<'e, E>(self, executor: E) -> Result, Error> where E: SqliteExecutor<'e>, { @@ -36,7 +36,7 @@ impl Connection { }) } - pub(crate) async fn latest_by_location_id<'e, E>( + pub async fn latest_by_location_id<'e, E>( executor: E, location_id: Id, ) -> Result>, Error> @@ -81,10 +81,7 @@ impl From for CommonConnectionInfo { } impl ConnectionInfo { - pub(crate) async fn all_by_location_id<'e, E>( - executor: E, - location_id: Id, - ) -> Result, Error> + pub async fn all_by_location_id<'e, E>(executor: E, location_id: Id) -> Result, Error> where E: SqliteExecutor<'e>, { @@ -130,11 +127,7 @@ pub struct ActiveConnection { impl ActiveConnection { #[must_use] - pub(crate) fn new( - location_id: Id, - interface_name: String, - connection_type: ConnectionType, - ) -> Self { + pub fn new(location_id: Id, interface_name: String, connection_type: ConnectionType) -> Self { let start = Utc::now().naive_utc(); Self { location_id, diff --git a/src-tauri/src/database/models/instance.rs b/src-tauri/core/src/database/models/instance.rs similarity index 84% rename from src-tauri/src/database/models/instance.rs rename to src-tauri/core/src/database/models/instance.rs index 9779e149..d01fead7 100644 --- a/src-tauri/src/database/models/instance.rs +++ b/src-tauri/core/src/database/models/instance.rs @@ -26,8 +26,8 @@ impl fmt::Display for Instance { } } -impl From for Instance { - fn from(instance_info: proto::defguard::client_types::InstanceInfo) -> Self { +impl From for Instance { + fn from(instance_info: proto::client_types::InstanceInfo) -> Self { let client_traffic_policy = ClientTrafficPolicy::from(&instance_info); Self { id: NoId, @@ -45,7 +45,7 @@ impl From for Instance { } impl Instance { - pub(crate) async fn save<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> + pub async fn save<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> where E: SqliteExecutor<'e>, { @@ -85,7 +85,7 @@ impl Instance { Ok(instances) } - pub(crate) async fn find_by_id<'e, E>(executor: E, id: Id) -> Result, sqlx::Error> + pub async fn find_by_id<'e, E>(executor: E, id: Id) -> Result, sqlx::Error> where E: SqliteExecutor<'e>, { @@ -101,7 +101,23 @@ impl Instance { Ok(instance) } - pub(crate) async fn delete_by_id<'e, E>(executor: E, id: Id) -> Result<(), sqlx::Error> + pub async fn find_by_name<'e, E>(executor: E, name: &str) -> Result, sqlx::Error> + where + E: SqliteExecutor<'e>, + { + let instance = query_as!( + Self, + "SELECT id \"id: _\", name, uuid, url, proxy_url, username, token \"token?\", \ + client_traffic_policy, enterprise_enabled, openid_display_name \ + FROM instance WHERE name = $1;", + name + ) + .fetch_optional(executor) + .await?; + Ok(instance) + } + + pub async fn delete_by_id<'e, E>(executor: E, id: Id) -> Result<(), sqlx::Error> where E: SqliteExecutor<'e>, { @@ -112,7 +128,7 @@ impl Instance { Ok(()) } - pub(crate) async fn delete<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> + pub async fn delete<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> where E: SqliteExecutor<'e>, { @@ -120,7 +136,7 @@ impl Instance { Ok(()) } - pub(crate) async fn all_with_token<'e, E>(executor: E) -> Result, sqlx::Error> + pub async fn all_with_token<'e, E>(executor: E) -> Result, sqlx::Error> where E: SqliteExecutor<'e>, { @@ -138,8 +154,8 @@ impl Instance { } // This compares proto::InstanceInfo, not to be confused with regular InstanceInfo defined below -impl PartialEq for Instance { - fn eq(&self, other: &proto::defguard::client_types::InstanceInfo) -> bool { +impl PartialEq for Instance { + fn eq(&self, other: &proto::client_types::InstanceInfo) -> bool { let other_policy = ClientTrafficPolicy::from(other); self.name == other.name && self.uuid == other.id @@ -223,8 +239,8 @@ pub enum ClientTrafficPolicy { } /// Retrieves `ClientTrafficPolicy` from `proto::InstanceInfo` while ensuring backwards compatibility -impl From<&proto::defguard::client_types::InstanceInfo> for ClientTrafficPolicy { - fn from(instance: &proto::defguard::client_types::InstanceInfo) -> Self { +impl From<&proto::client_types::InstanceInfo> for ClientTrafficPolicy { + fn from(instance: &proto::client_types::InstanceInfo) -> Self { match ( instance.client_traffic_policy, #[allow(deprecated)] diff --git a/src-tauri/src/database/models/location.rs b/src-tauri/core/src/database/models/location.rs similarity index 92% rename from src-tauri/src/database/models/location.rs rename to src-tauri/core/src/database/models/location.rs index 543d480b..110fa63b 100644 --- a/src-tauri/src/database/models/location.rs +++ b/src-tauri/core/src/database/models/location.rs @@ -10,17 +10,12 @@ use sqlx::{prelude::Type, query, query_as, query_scalar, SqliteExecutor}; #[cfg(not(target_os = "macos"))] use super::wireguard_keys::WireguardKeys; use super::{Id, NoId}; -#[cfg(not(target_os = "macos"))] -use crate::{ - database::DbPool, - utils::{DEFAULT_ROUTE_IPV4, DEFAULT_ROUTE_IPV6}, -}; -use crate::{ - error::Error, - proto::defguard::client_types::{ - LocationMfaMode as ProtoLocationMfaMode, ServiceLocationMode as ProtoServiceLocationMode, - }, +use crate::error::Error; +use crate::proto::client_types::{ + LocationMfaMode as ProtoLocationMfaMode, ServiceLocationMode as ProtoServiceLocationMode, }; +#[cfg(not(target_os = "macos"))] +use crate::{database::DbPool, DEFAULT_ROUTE_IPV4, DEFAULT_ROUTE_IPV6}; #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Type)] #[repr(u32)] @@ -76,7 +71,7 @@ pub enum LocationMfaMethod { MobileApprove = 4, } -pub(crate) fn infer_mfa_method( +pub fn infer_mfa_method( mode: LocationMfaMode, method: Option, ) -> Option { @@ -126,10 +121,7 @@ impl fmt::Display for Location { impl Location { /// Ignores service locations #[cfg(any(windows, target_os = "macos"))] - pub(crate) async fn all<'e, E>( - executor: E, - include_service_locations: bool, - ) -> sqlx::Result> + pub async fn all<'e, E>(executor: E, include_service_locations: bool) -> sqlx::Result> where E: SqliteExecutor<'e>, { @@ -151,10 +143,7 @@ impl Location { } #[cfg(any(windows, target_os = "macos"))] - pub(crate) async fn exist<'e, E>( - executor: E, - include_service_locations: bool, - ) -> sqlx::Result + pub async fn exist<'e, E>(executor: E, include_service_locations: bool) -> sqlx::Result where E: SqliteExecutor<'e>, { @@ -170,7 +159,7 @@ impl Location { Ok(result != 0) } - pub(crate) async fn save<'e, E>(&mut self, executor: E) -> sqlx::Result<()> + pub async fn save<'e, E>(&mut self, executor: E) -> sqlx::Result<()> where E: SqliteExecutor<'e>, { @@ -203,10 +192,7 @@ impl Location { Ok(()) } - pub(crate) async fn find_by_id<'e, E>( - executor: E, - location_id: Id, - ) -> sqlx::Result> + pub async fn find_by_id<'e, E>(executor: E, location_id: Id) -> sqlx::Result> where E: SqliteExecutor<'e>, { @@ -224,7 +210,7 @@ impl Location { .await } - pub(crate) async fn find_by_instance_id<'e, E>( + pub async fn find_by_instance_id<'e, E>( executor: E, instance_id: Id, include_service_locations: bool, @@ -250,7 +236,7 @@ impl Location { .await } - pub(crate) async fn find_by_public_key<'e, E>(executor: E, pubkey: &str) -> sqlx::Result + pub async fn find_by_public_key<'e, E>(executor: E, pubkey: &str) -> sqlx::Result where E: SqliteExecutor<'e>, { @@ -268,7 +254,7 @@ impl Location { .await } - pub(crate) async fn delete<'e, E>(&self, executor: E) -> sqlx::Result<()> + pub async fn delete<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: SqliteExecutor<'e>, { @@ -279,7 +265,7 @@ impl Location { } /// Disables all traffic for locations related to the given instance - pub(crate) async fn disable_all_traffic_for_all<'e, E>( + pub async fn disable_all_traffic_for_all<'e, E>( executor: E, instance_id: Id, ) -> Result<(), Error> @@ -295,7 +281,7 @@ impl Location { Ok(()) } - pub(crate) fn mfa_enabled(&self) -> bool { + pub fn mfa_enabled(&self) -> bool { match self.location_mfa_mode { LocationMfaMode::Disabled => false, LocationMfaMode::Internal | LocationMfaMode::External => true, @@ -303,7 +289,7 @@ impl Location { } #[cfg(not(target_os = "macos"))] - pub(crate) async fn interface_configuration( + pub async fn interface_configuration( &self, pool: &DbPool, interface_name: String, @@ -418,7 +404,7 @@ impl Location { } impl Location { - pub(crate) async fn save<'e, E>(self, executor: E) -> sqlx::Result> + pub async fn save<'e, E>(self, executor: E) -> sqlx::Result> where E: SqliteExecutor<'e>, { diff --git a/src-tauri/src/database/models/location_stats.rs b/src-tauri/core/src/database/models/location_stats.rs similarity index 92% rename from src-tauri/src/database/models/location_stats.rs rename to src-tauri/core/src/database/models/location_stats.rs index c0740ceb..4b1f12a3 100644 --- a/src-tauri/src/database/models/location_stats.rs +++ b/src-tauri/core/src/database/models/location_stats.rs @@ -6,18 +6,18 @@ use serde::{Deserialize, Serialize}; use sqlx::{query, query_as, query_scalar, SqliteExecutor}; use super::{location::Location, Id, NoId, PURGE_DURATION}; -use crate::{commands::DateTimeAggregation, error::Error, CommonLocationStats, ConnectionType}; +use crate::{error::Error, CommonLocationStats, ConnectionType, DateTimeAggregation}; #[derive(Debug, Serialize, Deserialize)] pub struct LocationStats { id: I, - pub(crate) location_id: Id, + pub location_id: Id, upload: i64, download: i64, - pub(crate) last_handshake: i64, - pub(crate) collected_at: NaiveDateTime, + pub last_handshake: i64, + pub collected_at: NaiveDateTime, listen_port: u32, - pub(crate) persistent_keepalive_interval: Option, + pub persistent_keepalive_interval: Option, } impl From> for CommonLocationStats { @@ -61,7 +61,7 @@ where impl LocationStats { // Although not used on macOS, allow dead code for `sqlx prepare`. #[cfg_attr(target_os = "macos", allow(dead_code))] - pub(crate) async fn get_name<'e, E>(&self, executor: E) -> Result + pub async fn get_name<'e, E>(&self, executor: E) -> Result where E: SqliteExecutor<'e>, { @@ -73,7 +73,7 @@ impl LocationStats { impl LocationStats { #[must_use] - pub(crate) fn new( + pub fn new( location_id: Id, upload: i64, download: i64, @@ -93,7 +93,7 @@ impl LocationStats { } } - pub(crate) async fn save<'e, E>(self, executor: E) -> Result, Error> + pub async fn save<'e, E>(self, executor: E) -> Result, Error> where E: SqliteExecutor<'e>, { @@ -127,7 +127,7 @@ impl LocationStats { } impl LocationStats { - pub(crate) async fn all_by_location_id<'e, E>( + pub async fn all_by_location_id<'e, E>( executor: E, location_id: Id, from: &NaiveDateTime, @@ -168,7 +168,7 @@ impl LocationStats { Ok(stats) } - pub(crate) async fn latest_by_download_change<'e, E>( + pub async fn latest_by_download_change<'e, E>( executor: E, location_id: Id, ) -> Result, Error> diff --git a/src-tauri/src/database/models/mod.rs b/src-tauri/core/src/database/models/mod.rs similarity index 100% rename from src-tauri/src/database/models/mod.rs rename to src-tauri/core/src/database/models/mod.rs diff --git a/src-tauri/src/database/models/settings.rs b/src-tauri/core/src/database/models/settings.rs similarity index 100% rename from src-tauri/src/database/models/settings.rs rename to src-tauri/core/src/database/models/settings.rs diff --git a/src-tauri/src/database/models/tunnel.rs b/src-tauri/core/src/database/models/tunnel.rs similarity index 94% rename from src-tauri/src/database/models/tunnel.rs rename to src-tauri/core/src/database/models/tunnel.rs index 5e92c9b5..f81ff18c 100644 --- a/src-tauri/src/database/models/tunnel.rs +++ b/src-tauri/core/src/database/models/tunnel.rs @@ -8,8 +8,8 @@ use sqlx::{query, query_as, query_scalar, Error as SqlxError, SqliteExecutor}; use super::{connection::ActiveConnection, Id, NoId, PURGE_DURATION}; use crate::{ - commands::DateTimeAggregation, error::Error, CommonConnection, CommonConnectionInfo, - CommonLocationStats, ConnectionType, + error::Error, CommonConnection, CommonConnectionInfo, CommonLocationStats, ConnectionType, + DateTimeAggregation, }; #[serde_as] @@ -57,7 +57,7 @@ impl fmt::Display for Tunnel { } impl Tunnel { - pub(crate) async fn save<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> + pub async fn save<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> where E: SqliteExecutor<'e>, { @@ -90,7 +90,7 @@ impl Tunnel { Ok(()) } - pub(crate) async fn delete<'e, E>(&self, executor: E) -> Result<(), Error> + pub async fn delete<'e, E>(&self, executor: E) -> Result<(), Error> where E: SqliteExecutor<'e>, { @@ -98,10 +98,7 @@ impl Tunnel { Ok(()) } - pub(crate) async fn find_by_id<'e, E>( - executor: E, - tunnel_id: Id, - ) -> Result, SqlxError> + pub async fn find_by_id<'e, E>(executor: E, tunnel_id: Id) -> Result, SqlxError> where E: SqliteExecutor<'e>, { @@ -116,7 +113,7 @@ impl Tunnel { .await } - pub(crate) async fn all<'e, E>(executor: E) -> Result, SqlxError> + pub async fn all<'e, E>(executor: E) -> Result, SqlxError> where E: SqliteExecutor<'e>, { @@ -132,7 +129,7 @@ impl Tunnel { Ok(tunnels) } - pub(crate) async fn find_by_server_public_key<'e, E>( + pub async fn find_by_server_public_key<'e, E>( executor: E, pubkey: &str, ) -> Result @@ -151,7 +148,7 @@ impl Tunnel { .await } - pub(crate) async fn delete_by_id<'e, E>(executor: E, id: Id) -> Result<(), Error> + pub async fn delete_by_id<'e, E>(executor: E, id: Id) -> Result<(), Error> where E: SqliteExecutor<'e>, { @@ -166,7 +163,7 @@ impl Tunnel { impl Tunnel { #[allow(clippy::too_many_arguments)] #[must_use] - pub(crate) fn new( + pub fn new( name: String, pubkey: String, prvkey: String, @@ -203,7 +200,7 @@ impl Tunnel { } } - pub(crate) async fn save<'e, E>(self, executor: E) -> Result, SqlxError> + pub async fn save<'e, E>(self, executor: E) -> Result, SqlxError> where E: SqliteExecutor<'e>, { @@ -255,13 +252,13 @@ impl Tunnel { #[derive(Debug, Serialize, Deserialize)] pub struct TunnelStats { id: I, - pub(crate) tunnel_id: Id, + pub tunnel_id: Id, upload: i64, download: i64, - pub(crate) last_handshake: i64, - pub(crate) collected_at: NaiveDateTime, + pub last_handshake: i64, + pub collected_at: NaiveDateTime, listen_port: u32, - pub(crate) persistent_keepalive_interval: u16, + pub persistent_keepalive_interval: u16, } impl TunnelStats { @@ -331,7 +328,7 @@ impl TunnelStats { } impl TunnelStats { - pub(crate) async fn all_by_tunnel_id<'e, E>( + pub async fn all_by_tunnel_id<'e, E>( executor: E, tunnel_id: Id, from: &NaiveDateTime, @@ -366,7 +363,7 @@ impl TunnelStats { Ok(stats) } - pub(crate) async fn latest_by_download_change<'e, E>( + pub async fn latest_by_download_change<'e, E>( executor: E, tunnel_id: Id, ) -> Result, Error> @@ -632,7 +629,7 @@ mod tests { } } - #[sqlx::test] + #[sqlx::test(migrations = "../migrations")] async fn purge_stats(pool: SqlitePool) { let tunnel = Tunnel::new( "test".into(), diff --git a/src-tauri/src/database/models/wireguard_keys.rs b/src-tauri/core/src/database/models/wireguard_keys.rs similarity index 100% rename from src-tauri/src/database/models/wireguard_keys.rs rename to src-tauri/core/src/database/models/wireguard_keys.rs diff --git a/src-tauri/src/error.rs b/src-tauri/core/src/error.rs similarity index 98% rename from src-tauri/src/error.rs rename to src-tauri/core/src/error.rs index 28a6db94..c9d815a6 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/core/src/error.rs @@ -27,7 +27,7 @@ pub enum Error { #[error("Object not found")] NotFound, #[error("Tauri error: {0}")] - Tauri(#[from] tauri::Error), + Tauri(String), #[error("Failed to parse str to enum")] StrumError(#[from] strum::ParseError), #[error("Required resource not found {0}")] diff --git a/src-tauri/core/src/events.rs b/src-tauri/core/src/events.rs new file mode 100644 index 00000000..1a3504e8 --- /dev/null +++ b/src-tauri/core/src/events.rs @@ -0,0 +1,36 @@ +// Match src/pages/client/types.ts. +#[non_exhaustive] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum EventKey { + ConnectionChanged, + InstanceUpdate, + LocationUpdate, + AppVersionFetch, + ConfigChanged, + DeadConnectionDropped, + DeadConnectionReconnected, + ApplicationConfigChanged, + AddInstance, + MfaTrigger, + VersionMismatch, + UuidMismatch, +} + +impl From for &'static str { + fn from(key: EventKey) -> &'static str { + match key { + EventKey::ConnectionChanged => "connection-changed", + EventKey::InstanceUpdate => "instance-update", + EventKey::LocationUpdate => "location-update", + EventKey::AppVersionFetch => "app-version-fetch", + EventKey::ConfigChanged => "config-changed", + EventKey::DeadConnectionDropped => "dead-connection-dropped", + EventKey::DeadConnectionReconnected => "dead-connection-reconnected", + EventKey::ApplicationConfigChanged => "application-config-changed", + EventKey::AddInstance => "add-instance", + EventKey::MfaTrigger => "mfa-trigger", + EventKey::VersionMismatch => "version-mismatch", + EventKey::UuidMismatch => "uuid-mismatch", + } + } +} diff --git a/src-tauri/core/src/lib.rs b/src-tauri/core/src/lib.rs new file mode 100644 index 00000000..db65bdb2 --- /dev/null +++ b/src-tauri/core/src/lib.rs @@ -0,0 +1,191 @@ +use std::{fmt, path::PathBuf}; + +use chrono::{Duration, NaiveDateTime, Utc}; +use database::models::Id; +use serde::{Deserialize, Serialize}; +#[cfg(unix)] +use std::{ + fs::{set_permissions, Permissions}, + os::unix::fs::PermissionsExt, +}; + +pub mod app_config; +pub mod connection; +pub mod database; +pub mod error; +pub mod events; +pub mod proxy; +pub mod version; +pub mod wg_config; + +// Re-export proto module for backward compatibility within core. +pub use defguard_client_proto::defguard as proto; + +use crate::database::models::NoId; + +#[macro_use] +extern crate log; + +const BUNDLE_IDENTIFIER: &str = "net.defguard"; + +/// Returns the path to the user's data directory. +#[must_use] +pub fn app_data_dir() -> Option { + dirs_next::data_dir().map(|dir| dir.join(BUNDLE_IDENTIFIER)) +} + +/// Ensures path has appropriate permissions set (dg25-28): +/// - 700 for directories +/// - 600 for files +#[cfg(unix)] +pub fn set_perms(path: &std::path::Path) { + let perms = if path.is_dir() { 0o700 } else { 0o600 }; + if let Err(err) = set_permissions(path, Permissions::from_mode(perms)) { + log::warn!( + "Failed to set permissions on path {}: {err}", + path.display() + ); + } +} + +/// Location type used in commands to check if we use tunnel or location +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Serialize)] +pub enum ConnectionType { + Tunnel, + Location, +} + +impl fmt::Display for ConnectionType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ConnectionType::Tunnel => write!(f, "tunnel"), + ConnectionType::Location => write!(f, "location"), + } + } +} + +/// Common fields for Tunnel and Location +#[derive(Debug, Serialize, Deserialize)] +pub struct CommonWireguardFields { + pub instance_id: Id, + pub network_id: Id, + pub name: String, + pub address: String, + pub pubkey: String, + pub endpoint: String, + pub allowed_ips: String, + pub dns: Option, + pub route_all_traffic: bool, +} + +/// Common fields for Connection and TunnelConnection due to shared command +#[derive(Debug, Serialize, Deserialize)] +pub struct CommonConnection { + pub id: I, + pub location_id: Id, + pub start: NaiveDateTime, + pub end: NaiveDateTime, + pub connection_type: ConnectionType, +} + +/// Common fields for LocationStats and TunnelStats due to shared command +#[derive(Debug, Serialize, Deserialize)] +pub struct CommonLocationStats { + pub id: I, + pub location_id: Id, + pub upload: i64, + pub download: i64, + pub last_handshake: i64, + pub collected_at: NaiveDateTime, + pub listen_port: u32, + pub persistent_keepalive_interval: Option, + pub connection_type: ConnectionType, +} + +/// Common fields for ConnectionInfo and TunnelConnectionInfo due to shared command +#[derive(Debug, Serialize)] +pub struct CommonConnectionInfo { + pub id: Id, + pub location_id: Id, + pub start: NaiveDateTime, + pub end: NaiveDateTime, + pub upload: Option, + pub download: Option, +} + +pub const DEFAULT_ROUTE_IPV4: &str = "0.0.0.0/0"; +pub const DEFAULT_ROUTE_IPV6: &str = "::/0"; + +pub enum DateTimeAggregation { + Hour, + Second, +} + +impl DateTimeAggregation { + #[must_use] + pub fn fstring(&self) -> &'static str { + match self { + Self::Hour => "%Y-%m-%d %H:00:00", + Self::Second => "%Y-%m-%d %H:%M:%S", + } + } +} + +pub fn get_aggregation(from: NaiveDateTime) -> Result { + let aggregation = match Utc::now().naive_utc() - from { + duration if duration >= Duration::hours(8) => Ok(DateTimeAggregation::Hour), + duration if duration < Duration::zero() => Err(error::Error::InternalError(format!( + "Negative duration between dates: now ({}) and {from}", + Utc::now().naive_utc(), + ))), + _ => Ok(DateTimeAggregation::Second), + }?; + Ok(aggregation) +} + +use database::models::location::{ + infer_mfa_method, Location, LocationMfaMode, ServiceLocationMode, +}; +use defguard_client_proto::defguard::client_types::DeviceConfig; + +#[must_use] +pub fn into_location(dev_config: DeviceConfig, instance_id: Id) -> Location { + use LocationMfaMode as MfaMode; + use ServiceLocationMode as SLocationMode; + + let location_mfa_mode = match dev_config.location_mfa_mode { + Some(_location_mfa_mode) => dev_config.location_mfa_mode().into(), + None => + { + #[allow(deprecated)] + if dev_config.mfa_enabled { + MfaMode::Internal + } else { + MfaMode::Disabled + } + } + }; + + let service_location_mode = match dev_config.service_location_mode { + Some(_service_location_mode) => dev_config.service_location_mode().into(), + None => SLocationMode::Disabled, + }; + + Location { + id: NoId, + instance_id, + network_id: dev_config.network_id, + name: dev_config.network_name, + address: dev_config.assigned_ip, + pubkey: dev_config.pubkey, + endpoint: dev_config.endpoint, + allowed_ips: dev_config.allowed_ips, + dns: dev_config.dns, + route_all_traffic: false, + keepalive_interval: dev_config.keepalive_interval.into(), + location_mfa_mode, + service_location_mode, + mfa_method: infer_mfa_method(location_mfa_mode, None), + posture_check_required: dev_config.posture_check_required.unwrap_or_default(), + } +} diff --git a/src-tauri/core/src/proxy.rs b/src-tauri/core/src/proxy.rs new file mode 100644 index 00000000..70b3537e --- /dev/null +++ b/src-tauri/core/src/proxy.rs @@ -0,0 +1,46 @@ +use std::{env, time::Duration}; + +use base64::{prelude::BASE64_STANDARD, Engine}; +use prost::Message; +use reqwest::{Client, Response, Url}; +use serde::Serialize; + +use crate::version::{CLIENT_PLATFORM_HEADER, CLIENT_VERSION_HEADER, PKG_VERSION}; +use defguard_client_proto::defguard::client_types::ClientPlatformInfo; + +const HTTP_REQ_TIMEOUT: Duration = Duration::from_secs(5); + +/// Build a base64-encoded `ClientPlatformInfo` header value. +#[must_use] +pub fn construct_platform_header() -> String { + let os = os_info::get(); + + let platform_info = ClientPlatformInfo { + os_family: env::consts::FAMILY.to_string(), + os_type: env::consts::OS.to_string(), + version: os.version().to_string(), + edition: os.edition().map(str::to_string), + codename: os.codename().map(str::to_string), + bitness: Some(os.bitness().to_string()), + architecture: Some(env::consts::ARCH.to_string()), + }; + + log::debug!("Constructed platform info header: {platform_info:?}"); + + BASE64_STANDARD.encode(platform_info.encode_to_vec()) +} + +/// Send a JSON POST request with the standard client version/platform headers and a short timeout. +pub async fn post_with_headers( + url: Url, + data: &T, +) -> Result { + Client::new() + .post(url) + .json(data) + .header(CLIENT_VERSION_HEADER, PKG_VERSION) + .header(CLIENT_PLATFORM_HEADER, construct_platform_header()) + .timeout(HTTP_REQ_TIMEOUT) + .send() + .await +} diff --git a/src-tauri/core/src/version.rs b/src-tauri/core/src/version.rs new file mode 100644 index 00000000..a053a791 --- /dev/null +++ b/src-tauri/core/src/version.rs @@ -0,0 +1,8 @@ +pub use semver::Version; + +pub const MIN_CORE_VERSION: Version = Version::new(1, 6, 0); +pub const MIN_PROXY_VERSION: Version = Version::new(1, 6, 0); +pub const CLIENT_VERSION_HEADER: &str = "defguard-client-version"; +pub const CLIENT_PLATFORM_HEADER: &str = "defguard-client-platform"; +pub const LOG_FILENAME: &str = "defguard-client"; +pub use defguard_client_common::VERSION as PKG_VERSION; diff --git a/src-tauri/src/wg_config.rs b/src-tauri/core/src/wg_config.rs similarity index 100% rename from src-tauri/src/wg_config.rs rename to src-tauri/core/src/wg_config.rs diff --git a/src-tauri/daemon/Cargo.toml b/src-tauri/daemon/Cargo.toml new file mode 100644 index 00000000..426e48da --- /dev/null +++ b/src-tauri/daemon/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "defguard-client-service" +description = "Defguard client daemon service" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license = "AGPL-3.0" +rust-version.workspace = true +version.workspace = true + +[dependencies] +anyhow = "1.0" +clap = { workspace = true } +defguard-client-common = { path = "../common" } +defguard-client-proto = { path = "../client-proto" } +defguard-client-posture = { path = "../enterprise/posture" } +defguard-client-service-locations = { path = "../enterprise/service-locations" } +defguard_wireguard_rs = { workspace = true } +log.workspace = true +serde.workspace = true +thiserror = { workspace = true } +tokio = { version = "1", features = ["net", "rt-multi-thread", "signal", "sync", "time"] } +tokio-stream = { version = "0.1", features = ["net"] } +tonic = { workspace = true } +tracing = { workspace = true } +tracing-appender = "0.2" +tracing-subscriber = { workspace = true } + +[target.'cfg(unix)'.dependencies] +nix = { version = "0.30", features = ["fs", "user"] } + +[target.'cfg(windows)'.dependencies] +async-stream = "0.3" +futures-core = "0.3" +tokio = { version = "1", features = ["net", "rt-multi-thread", "signal", "sync", "time"] } +windows-core = "0.61" +windows-service = "0.8" +windows-sys = "0.61" + +[[bin]] +name = "defguard-service" +path = "src/bin/defguard-service.rs" diff --git a/src-tauri/src/bin/defguard-service.rs b/src-tauri/daemon/src/bin/defguard-service.rs similarity index 79% rename from src-tauri/src/bin/defguard-service.rs rename to src-tauri/daemon/src/bin/defguard-service.rs index 7b4ec0cb..9aab6c2e 100644 --- a/src-tauri/src/bin/defguard-service.rs +++ b/src-tauri/daemon/src/bin/defguard-service.rs @@ -7,7 +7,7 @@ #[tokio::main] async fn main() -> anyhow::Result<()> { use clap::Parser; - use defguard_client::service::{config::Config, daemon::run_server, utils::logging_setup}; + use defguard_client_service::{config::Config, daemon::run_server, utils::logging_setup}; // parse config let config: Config = Config::parse(); @@ -21,5 +21,5 @@ async fn main() -> anyhow::Result<()> { #[cfg(windows)] fn main() -> windows_service::Result<()> { - defguard_client::service::windows::run() + defguard_client_service::windows::run() } diff --git a/src-tauri/daemon/src/config.rs b/src-tauri/daemon/src/config.rs new file mode 100644 index 00000000..e2f8528b --- /dev/null +++ b/src-tauri/daemon/src/config.rs @@ -0,0 +1,23 @@ +use clap::Parser; + +#[cfg(windows)] +pub const DEFAULT_LOG_DIR: &str = "/Logs/defguard-service"; +#[cfg(not(windows))] +pub const DEFAULT_LOG_DIR: &str = "/var/log/defguard-service"; + +#[derive(Debug, Parser, Clone)] +#[clap(about = "Defguard VPN client interface management service")] +#[command(version)] +pub struct Config { + /// Configures log level of defguard service logs + #[arg(long, env = "DEFGUARD_LOG_LEVEL", default_value = "info")] + pub log_level: String, + + /// Configures logging directory; it is meant for debugging only, so hide it. + #[arg(long, env = "DEFGUARD_LOG_DIR", default_value = DEFAULT_LOG_DIR, hide = true)] + pub log_dir: String, + + /// Defines how often (in seconds) interface statistics are sent to defguard client + #[arg(long, short = 'p', env = "DEFGUARD_STATS_PERIOD", default_value = "10")] + pub stats_period: u64, +} diff --git a/src-tauri/src/service/daemon.rs b/src-tauri/daemon/src/daemon.rs similarity index 91% rename from src-tauri/src/service/daemon.rs rename to src-tauri/daemon/src/daemon.rs index ac54904a..9cf4af49 100644 --- a/src-tauri/src/service/daemon.rs +++ b/src-tauri/daemon/src/daemon.rs @@ -7,7 +7,7 @@ use std::{ #[cfg(unix)] use std::{fs, os::unix::fs::PermissionsExt, path::Path}; -use common::dns_borrow; +use defguard_client_common::dns_borrow; use defguard_wireguard_rs::{ error::WireguardInterfaceError, InterfaceConfiguration, Kernel, WGApi, WireguardInterfaceApi, }; @@ -27,23 +27,21 @@ use tonic::{ use tracing::warn; use tracing::{debug, error, info, info_span, Instrument}; -use super::{ - config::Config, - proto::defguard::client::v1::{ - desktop_daemon_service_server::{DesktopDaemonService, DesktopDaemonServiceServer}, - CreateInterfaceRequest, DeleteServiceLocationsRequest, InterfaceData, - ReadInterfaceDataRequest, RemoveInterfaceRequest, SaveServiceLocationsRequest, - }, -}; -use crate::{ - enterprise::service_locations::ServiceLocationError, - service::proto::defguard::enterprise::posture::v2::DevicePostureData, VERSION, -}; +use crate::config::Config; #[cfg(windows)] -use crate::{ - enterprise::service_locations::ServiceLocationManager, - service::named_pipe::{get_named_pipe_server_stream, PIPE_NAME}, +use crate::named_pipe::{get_named_pipe_server_stream, PIPE_NAME}; +use crate::VERSION; +#[cfg(windows)] +use defguard_client_posture::inspector::device_posture_data; +use defguard_client_proto::defguard::client::v1::{ + desktop_daemon_service_server::{DesktopDaemonService, DesktopDaemonServiceServer}, + CreateInterfaceRequest, DeleteServiceLocationsRequest, InterfaceData, ReadInterfaceDataRequest, + RemoveInterfaceRequest, SaveServiceLocationsRequest, }; +use defguard_client_proto::defguard::enterprise::posture::v2::DevicePostureData; +use defguard_client_service_locations::ServiceLocationError; +#[cfg(windows)] +use defguard_client_service_locations::ServiceLocationManager; #[cfg(unix)] pub(super) const DAEMON_SOCKET_PATH: &str = "/var/run/defguard.socket"; @@ -520,7 +518,7 @@ impl DesktopDaemonService for DaemonService { _request: tonic::Request<()>, ) -> Result, Status> { debug!("Get posture data request received"); - Ok(Response::new(DevicePostureData::new())) + Ok(Response::new(device_posture_data())) } } @@ -539,21 +537,24 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { debug!("Binding socket file at {DAEMON_SOCKET_PATH}"); let uds = UnixListener::bind(DAEMON_SOCKET_PATH)?; - // change owner group for socket file - // get the group ID by name - let group = Group::from_name(DAEMON_SOCKET_GROUP)?.ok_or_else(|| { - error!("Group '{DAEMON_SOCKET_GROUP}' not found"); - crate::error::Error::InternalError(format!("Group '{DAEMON_SOCKET_GROUP}' not found")) - })?; + #[cfg(target_os = "linux")] + { + // change owner group for socket file + // get the group ID by name + let group = Group::from_name(DAEMON_SOCKET_GROUP)?.ok_or_else(|| { + error!("Group '{DAEMON_SOCKET_GROUP}' not found"); + crate::Error::Internal(format!("Group '{DAEMON_SOCKET_GROUP}' not found")) + })?; - // change ownership - keep current user, change group - debug!("Changing owner group of socket file at {DAEMON_SOCKET_PATH} to group {DAEMON_SOCKET_GROUP}"); - chown(DAEMON_SOCKET_PATH, None, Some(group.gid))?; + // change ownership - keep current user, change group + debug!("Changing owner group of socket file at {DAEMON_SOCKET_PATH} to group {DAEMON_SOCKET_GROUP}"); + chown(DAEMON_SOCKET_PATH, None, Some(group.gid))?; - // Set socket permissions to allow client access - // 0o660 allows read/write for owner and group only - debug!("Setting permissions for socket file at {DAEMON_SOCKET_PATH} to 0x660"); - fs::set_permissions(DAEMON_SOCKET_PATH, fs::Permissions::from_mode(0o660))?; + // Set socket permissions to allow client access + // 0o660 allows read/write for owner and group only + debug!("Setting permissions for socket file at {DAEMON_SOCKET_PATH} to 0x660"); + fs::set_permissions(DAEMON_SOCKET_PATH, fs::Permissions::from_mode(0o660))?; + } let uds_stream = UnixListenerStream::new(uds); @@ -561,7 +562,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { debug!("Defguard daemon configuration: {config:?}"); Server::builder() - .trace_fn(|_| tracing::info_span!("defguard_service")) + .trace_fn(|_| tracing::info_span!("defguard_client_service")) .add_service(DesktopDaemonServiceServer::new(daemon_service)) .serve_with_incoming(uds_stream) .await?; @@ -583,7 +584,7 @@ pub(crate) async fn run_server( debug!("Defguard daemon configuration: {config:?}"); Server::builder() - .trace_fn(|_| tracing::info_span!("defguard_service")) + .trace_fn(|_| tracing::info_span!("defguard_client_service")) .add_service(DesktopDaemonServiceServer::new(daemon_service)) .serve_with_incoming(stream) .await?; diff --git a/src-tauri/daemon/src/error.rs b/src-tauri/daemon/src/error.rs new file mode 100644 index 00000000..d1a1f11b --- /dev/null +++ b/src-tauri/daemon/src/error.rs @@ -0,0 +1,13 @@ +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("internal daemon error: {0}")] + Internal(String), + #[error("wireguard interface error: {0}")] + WireGuard(#[from] defguard_wireguard_rs::error::WireguardInterfaceError), + #[error("service location error: {0}")] + ServiceLocation(#[from] defguard_client_service_locations::ServiceLocationError), + #[error("conversion error: {0}")] + Conversion(String), + #[error("not found: {0}")] + NotFound(String), +} diff --git a/src-tauri/daemon/src/lib.rs b/src-tauri/daemon/src/lib.rs new file mode 100644 index 00000000..345540ee --- /dev/null +++ b/src-tauri/daemon/src/lib.rs @@ -0,0 +1,12 @@ +pub mod config; +pub mod daemon; +pub mod error; +pub mod utils; + +#[cfg(windows)] +pub mod named_pipe; +#[cfg(windows)] +pub mod windows; + +pub use defguard_client_common::VERSION; +pub use error::Error; diff --git a/src-tauri/src/service/named_pipe.rs b/src-tauri/daemon/src/named_pipe.rs similarity index 100% rename from src-tauri/src/service/named_pipe.rs rename to src-tauri/daemon/src/named_pipe.rs diff --git a/src-tauri/daemon/src/utils.rs b/src-tauri/daemon/src/utils.rs new file mode 100644 index 00000000..7796cf56 --- /dev/null +++ b/src-tauri/daemon/src/utils.rs @@ -0,0 +1,39 @@ +use std::io::stdout; + +use tracing::Level; +use tracing_appender::non_blocking::WorkerGuard; +use tracing_subscriber::{ + fmt, fmt::writer::MakeWriterExt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, + Layer, +}; + +pub fn logging_setup(log_dir: &str, log_level: &str) -> WorkerGuard { + // prepare log file appender + let file_appender = tracing_appender::rolling::daily(log_dir, "defguard-service.log"); + let (non_blocking, guard) = tracing_appender::non_blocking(file_appender); + + // prepare log level filter for stdout + let stdout_filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{log_level},hyper=info,h2=info").into()); + + // prepare log level filter for JSON file + let json_filter = EnvFilter::new("DEBUG,hyper=info,h2=info"); + + // prepare tracing layers + let stdout_layer = fmt::layer() + .pretty() + .with_writer(stdout.with_max_level(Level::DEBUG)) + .with_filter(stdout_filter); + let json_file_layer = fmt::layer() + .json() + .with_writer(non_blocking.with_max_level(Level::DEBUG)) + .with_filter(json_filter); + + // initialize tracing subscriber + tracing_subscriber::registry() + .with(stdout_layer) + .with(json_file_layer) + .init(); + + guard +} diff --git a/src-tauri/src/service/windows.rs b/src-tauri/daemon/src/windows.rs similarity index 97% rename from src-tauri/src/service/windows.rs rename to src-tauri/daemon/src/windows.rs index 00c5c0a0..07d30960 100644 --- a/src-tauri/src/service/windows.rs +++ b/src-tauri/daemon/src/windows.rs @@ -6,7 +6,6 @@ use std::{ }; use clap::Parser; -use error; use tokio::{runtime::Runtime, time::sleep}; use windows_service::{ define_windows_service, @@ -19,15 +18,13 @@ use windows_service::{ }; use crate::{ - enterprise::service_locations::{ - windows::{watch_for_login_logoff, watch_for_network_change}, - ServiceLocationError, ServiceLocationManager, - }, - service::{ - config::Config, - daemon::{run_server, DaemonError}, - utils::logging_setup, - }, + config::Config, + daemon::{run_server, DaemonError}, + utils::logging_setup, +}; +use defguard_client_service_locations::{ + windows::{watch_for_login_logoff, watch_for_network_change}, + ServiceLocationError, ServiceLocationManager, }; static SERVICE_NAME: &str = "DefguardService"; diff --git a/src-tauri/deny.toml b/src-tauri/deny.toml index fa04523f..118905c8 100644 --- a/src-tauri/deny.toml +++ b/src-tauri/deny.toml @@ -137,6 +137,27 @@ exceptions = [ { allow = [ "AGPL-3.0-or-later", ], crate = "defguard-client" }, + { allow = [ + "AGPL-3.0-or-later", + ], crate = "defguard-client-core" }, + { allow = [ + "AGPL-3.0-or-later", + ], crate = "defguard-service" }, + { allow = [ + "AGPL-3.0-or-later", + ], crate = "defguard-client-proto" }, + { allow = [ + "AGPL-3.0-or-later", + ], crate = "defguard-posture" }, + { allow = [ + "AGPL-3.0-or-later", + ], crate = "defguard-provisioning" }, + { allow = [ + "AGPL-3.0-or-later", + ], crate = "defguard-config-sync" }, + { allow = [ + "AGPL-3.0-or-later", + ], crate = "defguard-service-locations" }, ] # Some crates don't have (easily) machine readable licensing information, diff --git a/src-tauri/src/enterprise/LICENSE.md b/src-tauri/enterprise/LICENSE.md similarity index 100% rename from src-tauri/src/enterprise/LICENSE.md rename to src-tauri/enterprise/LICENSE.md diff --git a/src-tauri/enterprise/config-sync/Cargo.toml b/src-tauri/enterprise/config-sync/Cargo.toml new file mode 100644 index 00000000..4fd6e366 --- /dev/null +++ b/src-tauri/enterprise/config-sync/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "defguard-client-config-sync" +description = "Real-time configuration sync for the Defguard desktop client" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license-file = "../LICENSE.md" +rust-version.workspace = true +version.workspace = true + +[dependencies] +defguard-client-core = { path = "../../core" } +defguard-client-proto = { path = "../../client-proto" } +defguard-client-service-locations = { path = "../service-locations" } +log.workspace = true +reqwest.workspace = true +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +sqlx.workspace = true +tokio = { version = "1", features = ["time"] } + +[target.'cfg(not(target_os = "macos"))'.dependencies] +tonic = { workspace = true } diff --git a/src-tauri/enterprise/config-sync/src/commands.rs b/src-tauri/enterprise/config-sync/src/commands.rs new file mode 100644 index 00000000..d8775528 --- /dev/null +++ b/src-tauri/enterprise/config-sync/src/commands.rs @@ -0,0 +1,271 @@ +use std::collections::HashSet; + +use defguard_client_core::{ + database::models::{ + instance::{ClientTrafficPolicy, Instance}, + location::{infer_mfa_method, Location}, + wireguard_keys::WireguardKeys, + Id, NoId, + }, + error::Error, + into_location, +}; +use defguard_client_proto::defguard::{ + client::v1::{DeleteServiceLocationsRequest, SaveServiceLocationsRequest}, + client_types::DeviceConfigResponse, +}; +use sqlx::{Sqlite, SqliteExecutor, Transaction}; + +#[cfg(not(target_os = "macos"))] +use defguard_client_core::connection::daemon_client::DAEMON_CLIENT; +use defguard_client_service_locations::to_service_location; + +pub async fn locations_changed( + transaction: &mut Transaction<'_, Sqlite>, + instance: &Instance, + device_config: &DeviceConfigResponse, +) -> Result { + let db_locations = Location::find_by_instance_id(transaction.as_mut(), instance.id, true) + .await? + .into_iter() + .map(|location| { + let mut new_location = Location::::from(location); + new_location.route_all_traffic = false; + new_location.mfa_method = infer_mfa_method(new_location.location_mfa_mode, None); + new_location + }) + .collect::>(); + let core_locations: HashSet = device_config + .configs + .iter() + .map(|config| into_location(config.clone(), instance.id)) + .collect::>(); + + Ok(db_locations != core_locations) +} + +pub async fn do_update_instance( + transaction: &mut Transaction<'_, Sqlite>, + instance: &mut Instance, + response: DeviceConfigResponse, +) -> Result<(), Error> { + log::debug!("Updating instance {instance}"); + let locations_changed_val = locations_changed(transaction, instance, &response).await?; + let instance_info = response + .instance + .expect("Missing instance info in device config response"); + instance.name = instance_info.name; + instance.url = instance_info.url; + instance.proxy_url = instance_info.proxy_url; + instance.username = instance_info.username; + let policy = instance_info.client_traffic_policy.into(); + if instance.client_traffic_policy != policy && policy == ClientTrafficPolicy::DisableAllTraffic + { + log::debug!("Disabling all traffic for all locations of instance {instance}"); + Location::disable_all_traffic_for_all(transaction.as_mut(), instance.id).await?; + log::debug!("Disabled all traffic for all locations of instance {instance}"); + } + instance.client_traffic_policy = instance_info.client_traffic_policy.into(); + instance.openid_display_name = instance_info.openid_display_name; + instance.uuid = instance_info.id; + if response.token.is_some() { + instance.token = response.token; + log::debug!("Set polling token for instance {}", instance.name); + } else { + log::debug!( + "No polling token received for instance {}, not updating", + instance.name + ); + } + instance.save(transaction.as_mut()).await?; + log::debug!( + "A new base configuration has been applied to instance {instance}, even if nothing changed" + ); + + let mut service_locations = Vec::new(); + + if locations_changed_val { + log::debug!( + "Updating locations for instance {}({}).", + instance.name, + instance.id + ); + let mut current_locations = + Location::find_by_instance_id(transaction.as_mut(), instance.id, true).await?; + for dev_config in response.configs { + let new_location = into_location(dev_config, instance.id); + + let saved_location = if let Some(position) = current_locations + .iter() + .position(|loc| loc.network_id == new_location.network_id) + { + let mut current_location = current_locations.remove(position); + log::debug!( + "Updating existing location {}({}) for instance {}({}).", + current_location.name, + current_location.id, + instance.name, + instance.id, + ); + current_location.name = new_location.name; + current_location.address = new_location.address; + current_location.pubkey = new_location.pubkey; + current_location.endpoint = new_location.endpoint; + current_location.allowed_ips = new_location.allowed_ips; + current_location.keepalive_interval = new_location.keepalive_interval; + current_location.dns = new_location.dns; + current_location.location_mfa_mode = new_location.location_mfa_mode; + current_location.service_location_mode = new_location.service_location_mode; + current_location.mfa_method = infer_mfa_method( + current_location.location_mfa_mode, + current_location.mfa_method, + ); + current_location.posture_check_required = new_location.posture_check_required; + current_location.save(transaction.as_mut()).await?; + log::info!( + "Location {current_location} configuration updated for instance {instance}" + ); + current_location + } else { + log::debug!("Creating new location {new_location} for instance {instance}"); + let new_location = new_location.save(transaction.as_mut()).await?; + log::info!("New location {new_location} created for instance {instance}"); + new_location + }; + + if saved_location.is_service_location() { + log::debug!( + "Adding service location {}({}) for instance {}({}) to be saved to the daemon.", + saved_location.name, + saved_location.id, + instance.name, + instance.id, + ); + service_locations.push(to_service_location(&saved_location)?); + } + } + + log::debug!("Removing locations for instance {instance}"); + for removed_location in current_locations { + removed_location.delete(transaction.as_mut()).await?; + log::info!( + "Removed location {removed_location} for instance {instance} during instance update" + ); + } + log::debug!("Finished updating locations for instance {instance}"); + } else { + log::info!("Locations for instance {instance} didn't change. Not updating them."); + } + + if service_locations.is_empty() { + log::debug!( + "No service locations for instance {}({}), removing all existing service locations.", + instance.name, + instance.id + ); + + #[cfg(not(target_os = "macos"))] + { + let delete_request = DeleteServiceLocationsRequest { + instance_id: instance.uuid.clone(), + }; + DAEMON_CLIENT + .clone() + .delete_service_locations(delete_request) + .await + .map_err(|err| { + log::error!( + "Error while deleting service locations from the daemon for instance {}({ \ + }): {err}", + instance.name, + instance.id, + ); + Error::InternalError(err.to_string()) + })?; + log::debug!( + "Successfully removed all service locations from daemon for instance {}({})", + instance.name, + instance.id + ); + } + } else { + log::debug!( + "Processing {} service location(s) for instance {}({})", + service_locations.len(), + instance.name, + instance.id + ); + + #[cfg(not(target_os = "macos"))] + { + let private_key = WireguardKeys::find_by_instance_id(transaction.as_mut(), instance.id) + .await? + .ok_or(Error::NotFound)? + .prvkey; + + let save_request = SaveServiceLocationsRequest { + service_locations: service_locations.clone(), + instance_id: instance.uuid.clone(), + private_key, + }; + + log::debug!( + "Sending request to daemon to save {} service location(s) for instance {}({})", + save_request.service_locations.len(), + instance.name, + instance.id + ); + + DAEMON_CLIENT + .clone() + .save_service_locations(save_request) + .await + .map_err(|err| { + log::error!( + "Error while saving service locations to the daemon for instance {}({}): \ + {err}", + instance.name, + instance.id, + ); + Error::InternalError(err.to_string()) + })?; + + log::info!( + "Successfully saved {} service location(s) to daemon for instance {}({})", + service_locations.len(), + instance.name, + instance.id + ); + + log::debug!( + "Completed processing all service locations for instance {}({})", + instance.name, + instance.id + ); + } + } + + Ok(()) +} + +pub async fn disable_enterprise_features<'e, E>( + instance: &mut Instance, + executor: E, +) -> Result<(), Error> +where + E: SqliteExecutor<'e>, +{ + log::debug!( + "Disabling enterprise features for instance {}({})", + instance.name, + instance.id + ); + instance.client_traffic_policy = ClientTrafficPolicy::None; + instance.save(executor).await?; + log::debug!( + "Disabled enterprise features for instance {}({})", + instance.name, + instance.id + ); + Ok(()) +} diff --git a/src-tauri/enterprise/config-sync/src/lib.rs b/src-tauri/enterprise/config-sync/src/lib.rs new file mode 100644 index 00000000..c6c6b4ea --- /dev/null +++ b/src-tauri/enterprise/config-sync/src/lib.rs @@ -0,0 +1,300 @@ +#[macro_use] +extern crate log; + +use std::cmp::Ordering; +use std::str::FromStr; + +pub mod commands; + +use defguard_client_core::{ + database::models::{instance::Instance, Id}, + error::Error, + proxy::post_with_headers, + version::{MIN_CORE_VERSION, MIN_PROXY_VERSION}, +}; +use defguard_client_proto::defguard::client_types::{InstanceInfoRequest, InstanceInfoResponse}; +use reqwest::{StatusCode, Url}; +use semver::Version; +use serde::Serialize; +use sqlx::{Sqlite, Transaction}; + +use crate::commands::disable_enterprise_features; + +static POLLING_ENDPOINT: &str = "/api/v1/poll"; + +const CORE_VERSION_HEADER: &str = "defguard-core-version"; +const CORE_CONNECTED_HEADER: &str = "defguard-core-connected"; +const PROXY_VERSION_HEADER: &str = "defguard-component-version"; + +/// Result of a successful config fetch from the proxy. +#[derive(Debug)] +pub struct FetchedConfig { + pub response: InstanceInfoResponse, + pub version_mismatch: Option, +} + +/// Payload emitted when a version mismatch is detected. +#[derive(Clone, Debug, Serialize)] +pub struct VersionMismatchPayload { + pub instance_name: String, + pub instance_id: Id, + pub core_version: String, + pub proxy_version: String, + pub core_required_version: String, + pub proxy_required_version: String, + pub core_compatible: bool, + pub proxy_compatible: bool, +} + +/// Talks to the proxy for a single instance: builds the request, POSTs it, +/// handles 402 PAYMENT_REQUIRED by disabling enterprise features, parses the +/// response, and checks the version headers. +/// +/// Does **not** apply config changes or emit events — those are the caller's +/// responsibility. +pub async fn fetch_instance_config( + transaction: &mut Transaction<'_, Sqlite>, + instance: &mut Instance, +) -> Result { + debug!("Getting config from core for instance {}", instance.name); + + let request = build_request(instance)?; + let url = Url::from_str(&instance.proxy_url) + .and_then(|url| url.join(POLLING_ENDPOINT)) + .map_err(|_| { + Error::InternalError(format!( + "Can't build polling url: {}/{POLLING_ENDPOINT}", + instance.proxy_url + )) + })?; + let response = post_with_headers(url, &request).await.map_err(|err| { + Error::InternalError(format!( + "HTTP request failed for instance {}({}), url: {}, {err}", + instance.name, instance.id, instance.proxy_url + )) + })?; + debug!( + "Got the following config response for instance {} from core: {response:?}", + instance.name + ); + + let version_mismatch = check_min_version(&response, instance); + + // Return early if the enterprise features are disabled in the core + if response.status() == StatusCode::PAYMENT_REQUIRED { + debug!( + "Instance {}({}) has enterprise features disabled, checking if this state is reflected \ + on our end.", + instance.name, instance.id + ); + if instance.enterprise_enabled { + info!( + "Instance {}({}) has enterprise features disabled, but we have them enabled, \ + disabling.", + instance.name, instance.id + ); + disable_enterprise_features(instance, transaction.as_mut()).await?; + } else { + debug!( + "Instance {}({}) has enterprise features disabled, and we have them disabled as \ + well, no action needed", + instance.name, instance.id + ); + } + return Err(Error::CoreNotEnterprise); + } + + // Parse the response + debug!( + "Parsing the config response for instance {}.", + instance.name + ); + let response: InstanceInfoResponse = response.json().await.map_err(|err| { + Error::InternalError(format!( + "Failed to parse InstanceInfoResponse for instance {}({}): {err}", + instance.name, instance.id, + )) + })?; + + if response.device_config.is_none() { + return Err(Error::InternalError( + "Device config not present in response".to_string(), + )); + } + + debug!("Parsed the config for instance {}", instance.name); + trace!("Parsed config: {:?}", response.device_config); + + Ok(FetchedConfig { + response, + version_mismatch, + }) +} + +/// Checks if config has changed compared to what's in the database. +pub async fn config_changed( + transaction: &mut Transaction<'_, Sqlite>, + instance: &Instance, + device_config: &defguard_client_proto::defguard::client_types::DeviceConfigResponse, +) -> Result { + debug!( + "Checking if config and any of the locations changed for instance {}({})", + instance.name, instance.id + ); + let locations_changed = + commands::locations_changed(transaction, instance, device_config).await?; + let info_changed = match &device_config.instance { + Some(info) => instance != info, + None => false, + }; + debug!( + "Did the locations change?: {locations_changed}. Did the instance information change?: \ + {info_changed}" + ); + Ok(locations_changed || info_changed) +} + +// --- private helpers ------------------------------------------------------- + +/// Retrieves token to build InstanceInfoRequest +fn build_request(instance: &Instance) -> Result { + let token = instance.token.as_ref().ok_or_else(|| Error::NoToken)?; + + Ok(InstanceInfoRequest { + token: (*token).clone(), + }) +} + +/// Checks response headers for version compatibility. +/// Pure — returns `Some(payload)` when versions are incompatible, `None` when +/// everything is compatible or headers are missing. +fn check_min_version( + response: &reqwest::Response, + instance: &Instance, +) -> Option { + let detected_core_version: String; + let detected_proxy_version: String; + + let defguard_core_connected: Option = response + .headers() + .get(CORE_CONNECTED_HEADER) + .and_then(|v| { + debug!( + "Defguard core connection status header for instance {}({}): {v:?}", + instance.name, instance.id + ); + v.to_str().ok() + }) + .and_then(|s| s.parse().ok()); + + let core_compatible = if let Some(core_version) = response.headers().get(CORE_VERSION_HEADER) { + if let Ok(core_version) = core_version.to_str() { + if let Ok(core_version) = Version::from_str(core_version) { + detected_core_version = core_version.to_string(); + core_version.cmp_precedence(&MIN_CORE_VERSION) != Ordering::Less + } else { + warn!( + "Core version header: invalid semver string in response for instance {}({}): \ + '{core_version}'", + instance.name, instance.id + ); + detected_core_version = core_version.to_string(); + false + } + } else { + warn!( + "Core version header: invalid string in response for instance {}({}): \ + '{core_version:?}'", + instance.name, instance.id + ); + detected_core_version = "unknown".to_string(); + false + } + } else { + warn!( + "Core version header not present in response for instance {}({})", + instance.name, instance.id + ); + detected_core_version = "unknown".to_string(); + false + }; + + let proxy_compatible = if let Some(proxy_version) = response.headers().get(PROXY_VERSION_HEADER) + { + if let Ok(proxy_version) = proxy_version.to_str() { + if let Ok(proxy_version) = Version::from_str(proxy_version) { + detected_proxy_version = proxy_version.to_string(); + proxy_version.cmp_precedence(&MIN_PROXY_VERSION) != Ordering::Less + } else { + warn!( + "Proxy version header not a valid semver string in response for instance \ + {}({}): '{proxy_version}'", + instance.name, instance.id + ); + detected_proxy_version = proxy_version.to_string(); + false + } + } else { + warn!( + "Proxy version header not a valid string in response for instance {}({}): \ + '{proxy_version:?}'", + instance.name, instance.id + ); + detected_proxy_version = "unknown".to_string(); + false + } + } else { + warn!( + "Proxy version header not present in response for instance {}({})", + instance.name, instance.id + ); + detected_proxy_version = "unknown".to_string(); + false + }; + + let should_inform = match defguard_core_connected { + Some(true) => { + debug!( + "Defguard core is connected for instance {}({})", + instance.name, instance.id + ); + true + } + Some(false) => { + info!( + "Defguard core is not connected for instance {}({})", + instance.name, instance.id + ); + false + } + None => { + debug!( + "Defguard core connection status unknown for instance {}({})", + instance.name, instance.id + ); + true + } + }; + + if should_inform && (!core_compatible || !proxy_compatible) { + warn!( + "Instance {} is running incompatible versions: core {detected_core_version}, proxy \ + {detected_proxy_version}. Required versions: core >= {MIN_CORE_VERSION}, proxy >= \ + {MIN_PROXY_VERSION}", + instance.name, + ); + + Some(VersionMismatchPayload { + instance_name: instance.name.clone(), + instance_id: instance.id, + core_version: detected_core_version, + proxy_version: detected_proxy_version, + core_required_version: MIN_CORE_VERSION.to_string(), + proxy_required_version: MIN_PROXY_VERSION.to_string(), + core_compatible, + proxy_compatible, + }) + } else { + None + } +} diff --git a/src-tauri/enterprise/posture/Cargo.toml b/src-tauri/enterprise/posture/Cargo.toml new file mode 100644 index 00000000..754627d7 --- /dev/null +++ b/src-tauri/enterprise/posture/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "defguard-client-posture" +description = "Device posture checks for the Defguard desktop client" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license-file = "../LICENSE.md" +rust-version.workspace = true +version.workspace = true + +[dependencies] +defguard-client-core = { path = "../../core" } +defguard-client-proto = { path = "../../client-proto" } +log.workspace = true +reqwest.workspace = true +serde.workspace = true +serde_json.workspace = true + +[target.'cfg(target_os = "linux")'.dependencies] +sysinfo = { version = "0.39", default-features = false, features = ["system"] } + +[target.'cfg(target_os = "macos")'.dependencies] +sysinfo = { version = "0.39", default-features = false, features = ["system"] } + +[target.'cfg(windows)'.dependencies] +sysinfo = { version = "0.39", default-features = false, features = ["system"] } diff --git a/src-tauri/src/enterprise/inspector/linux.rs b/src-tauri/enterprise/posture/src/inspector/linux.rs similarity index 100% rename from src-tauri/src/enterprise/inspector/linux.rs rename to src-tauri/enterprise/posture/src/inspector/linux.rs diff --git a/src-tauri/src/enterprise/inspector/macos.rs b/src-tauri/enterprise/posture/src/inspector/macos.rs similarity index 100% rename from src-tauri/src/enterprise/inspector/macos.rs rename to src-tauri/enterprise/posture/src/inspector/macos.rs diff --git a/src-tauri/enterprise/posture/src/inspector/mod.rs b/src-tauri/enterprise/posture/src/inspector/mod.rs new file mode 100644 index 00000000..32d709bb --- /dev/null +++ b/src-tauri/enterprise/posture/src/inspector/mod.rs @@ -0,0 +1,135 @@ +#[cfg(target_os = "linux")] +pub(crate) mod linux; +#[cfg(target_os = "macos")] +pub(crate) mod macos; +#[cfg(test)] +mod tests; +#[cfg(windows)] +pub(crate) mod windows; + +use std::env::consts::OS; + +use sysinfo::System; + +use defguard_client_core::version::PKG_VERSION; +use defguard_client_proto::defguard::enterprise::posture::v2::{ + BoolCheck, DevicePostureData, Int32Check, StringCheck, UnavailableReason, +}; + +/// Returns the operating system name. +fn os_name() -> Result { + System::name().ok_or(UnavailableReason::DetectionFailed) +} + +/// Returns the operating system version. +fn os_version() -> Result { + #[cfg(windows)] + { + // Windows can report versions like "11 (26200)"; core expects a parseable major. + System::os_version() + .and_then(|version| version.split_whitespace().next().map(ToString::to_string)) + .ok_or(UnavailableReason::DetectionFailed) + } + + #[cfg(not(windows))] + { + System::os_version().ok_or(UnavailableReason::DetectionFailed) + } +} + +/// Returns the Linux kernel version. +fn linux_kernel_version() -> Result { + #[cfg(target_os = "linux")] + { + System::kernel_version().ok_or(UnavailableReason::DetectionFailed) + } + + #[cfg(not(target_os = "linux"))] + { + Err(UnavailableReason::NotApplicable) + } +} + +/// Returns the disk encryption status, preferably for the system volume. +fn disk_encryption_status() -> Result { + #[cfg(target_os = "macos")] + { + macos::disk_encryption_status() + } + + #[cfg(windows)] + { + windows::disk_encryption_status() + } + + #[cfg(target_os = "linux")] + { + linux::disk_encryption_status() + } +} + +/// Returns the antivirus status. +fn anti_virus_status() -> Result { + #[cfg(windows)] + { + windows::anti_virus_status() + } + + #[cfg(not(windows))] + { + Err(UnavailableReason::NotApplicable) + } +} + +/// Checks whether the computer is part of a domain. +fn part_of_domain() -> Result { + #[cfg(windows)] + { + windows::part_of_domain() + } + + #[cfg(not(windows))] + { + Err(UnavailableReason::NotApplicable) + } +} + +/// Returns the device integrity status. +fn device_integrity() -> Result { + #[cfg(target_os = "macos")] + { + macos::system_integrity_status() + } + + #[cfg(not(target_os = "macos"))] + Err(UnavailableReason::NotApplicable) +} + +/// Returns the number of days since the last installed Windows security update. +fn security_update_age_days() -> Result { + #[cfg(windows)] + { + windows::security_update_age_days() + } + + #[cfg(not(windows))] + { + Err(UnavailableReason::NotApplicable) + } +} + +#[must_use] +pub(crate) fn device_posture_data() -> DevicePostureData { + DevicePostureData { + defguard_client_version: PKG_VERSION.to_owned(), + os_type: OS.to_string(), + os_name: Some(StringCheck::from(os_name())), + os_version: Some(StringCheck::from(os_version())), + disk_encryption: Some(BoolCheck::from(disk_encryption_status())), + antivirus_present: Some(BoolCheck::from(anti_virus_status())), + windows_ad_domain_joined: Some(BoolCheck::from(part_of_domain())), + windows_security_update_age_days: Some(Int32Check::from(security_update_age_days())), + linux_kernel_version: Some(StringCheck::from(linux_kernel_version())), + device_integrity: Some(BoolCheck::from(device_integrity())), + } +} diff --git a/src-tauri/src/enterprise/inspector/tests/linux.rs b/src-tauri/enterprise/posture/src/inspector/tests/linux.rs similarity index 100% rename from src-tauri/src/enterprise/inspector/tests/linux.rs rename to src-tauri/enterprise/posture/src/inspector/tests/linux.rs diff --git a/src-tauri/src/enterprise/inspector/tests/macos.rs b/src-tauri/enterprise/posture/src/inspector/tests/macos.rs similarity index 100% rename from src-tauri/src/enterprise/inspector/tests/macos.rs rename to src-tauri/enterprise/posture/src/inspector/tests/macos.rs diff --git a/src-tauri/src/enterprise/inspector/tests/mod.rs b/src-tauri/enterprise/posture/src/inspector/tests/mod.rs similarity index 100% rename from src-tauri/src/enterprise/inspector/tests/mod.rs rename to src-tauri/enterprise/posture/src/inspector/tests/mod.rs diff --git a/src-tauri/src/enterprise/inspector/tests/windows.rs b/src-tauri/enterprise/posture/src/inspector/tests/windows.rs similarity index 100% rename from src-tauri/src/enterprise/inspector/tests/windows.rs rename to src-tauri/enterprise/posture/src/inspector/tests/windows.rs diff --git a/src-tauri/src/enterprise/inspector/windows.rs b/src-tauri/enterprise/posture/src/inspector/windows.rs similarity index 100% rename from src-tauri/src/enterprise/inspector/windows.rs rename to src-tauri/enterprise/posture/src/inspector/windows.rs diff --git a/src-tauri/enterprise/posture/src/lib.rs b/src-tauri/enterprise/posture/src/lib.rs new file mode 100644 index 00000000..da9d2d7b --- /dev/null +++ b/src-tauri/enterprise/posture/src/lib.rs @@ -0,0 +1,7 @@ +#[macro_use] +extern crate log; + +pub mod inspector; +pub mod posture; + +pub use posture::{authorize_posture_session, get_posture_data}; diff --git a/src-tauri/src/enterprise/posture.rs b/src-tauri/enterprise/posture/src/posture.rs similarity index 87% rename from src-tauri/src/enterprise/posture.rs rename to src-tauri/enterprise/posture/src/posture.rs index f2641a1e..5f113b8d 100644 --- a/src-tauri/src/enterprise/posture.rs +++ b/src-tauri/enterprise/posture/src/posture.rs @@ -1,20 +1,22 @@ -use reqwest::StatusCode; +use reqwest::{StatusCode, Url}; use serde::Deserialize; #[cfg(windows)] -use crate::service::client::DAEMON_CLIENT; -use crate::{ +use defguard_client_core::connection::daemon_client::DAEMON_CLIENT; +use defguard_client_core::{ database::{ models::{instance::Instance, location::Location, wireguard_keys::WireguardKeys, Id}, DB_POOL, }, error::Error, - service::proto::defguard::enterprise::posture::v2::{ - DevicePostureCheckRequest, DevicePostureCheckResponse, DevicePostureData, - }, - utils::post_with_headers, + proxy::post_with_headers, +}; +use defguard_client_proto::defguard::enterprise::posture::v2::{ + DevicePostureCheckRequest, DevicePostureCheckResponse, DevicePostureData, }; +use crate::inspector::device_posture_data; + const POSTURE_ENDPOINT: &str = "/api/v1/posture/connect"; /// Collects device posture data, sends it to the proxy, and returns the runtime preshared key. @@ -40,7 +42,7 @@ pub async fn authorize_posture_session(location: &Location) -> Result Result { } #[cfg(not(windows))] { - Ok(DevicePostureData::new()) + Ok(device_posture_data()) } } diff --git a/src-tauri/enterprise/provisioning/Cargo.toml b/src-tauri/enterprise/provisioning/Cargo.toml new file mode 100644 index 00000000..57b01fa3 --- /dev/null +++ b/src-tauri/enterprise/provisioning/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "defguard-client-provisioning" +description = "Zero-touch provisioning for the Defguard desktop client" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license-file = "../LICENSE.md" +rust-version.workspace = true +version.workspace = true + +[dependencies] +defguard-client-core = { path = "../../core" } +log.workspace = true +serde.workspace = true +serde_json.workspace = true diff --git a/src-tauri/enterprise/provisioning/src/lib.rs b/src-tauri/enterprise/provisioning/src/lib.rs new file mode 100644 index 00000000..cdd3e3dc --- /dev/null +++ b/src-tauri/enterprise/provisioning/src/lib.rs @@ -0,0 +1,66 @@ +use std::{fmt, fs, path::Path}; + +use serde::{Deserialize, Serialize}; + +const CONFIG_FILE_NAME: &str = "provisioning.json"; + +#[derive(Clone, Deserialize, Serialize)] +pub struct ProvisioningConfig { + pub enrollment_url: String, + pub enrollment_token: String, +} + +impl fmt::Debug for ProvisioningConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self { + enrollment_url, + enrollment_token: _, + } = self; + + f.debug_struct("ProvisioningConfig") + .field("enrollment_url", enrollment_url) + .field("enrollment_token", &"***") + .finish() + } +} + +impl ProvisioningConfig { + /// Load configuration from a file at `path`. + fn load(path: &Path) -> Option { + let file_content = match fs::read_to_string(path) { + Ok(content) => content, + Err(err) => { + log::warn!( + "Failed to open provisioning configuration file at {}. Error details: {err}", + path.display() + ); + return None; + } + }; + + let file_content = file_content.trim_start_matches('\u{FEFF}'); + + match serde_json::from_str::(file_content) { + Ok(config) => Some(config), + Err(err) => { + log::warn!( + "Failed to parse provisioning configuration file at {}. Error details: {err}", + path.display() + ); + None + } + } + } +} + +/// Try to find and load the provisioning configuration from the given app data directory. +#[must_use] +pub fn try_get_provisioning_config(app_data_dir: &Path) -> Option { + log::debug!( + "Trying to find provisioning config in {}", + app_data_dir.display() + ); + + let config_file_path = app_data_dir.join(CONFIG_FILE_NAME); + ProvisioningConfig::load(&config_file_path) +} diff --git a/src-tauri/enterprise/service-locations/Cargo.toml b/src-tauri/enterprise/service-locations/Cargo.toml new file mode 100644 index 00000000..833ad0a7 --- /dev/null +++ b/src-tauri/enterprise/service-locations/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "defguard-client-service-locations" +description = "Service location management for the Defguard client daemon" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license-file = "../LICENSE.md" +rust-version.workspace = true +version.workspace = true + +[dependencies] +defguard-client-common = { path = "../../common" } +defguard-client-core = { path = "../../core" } +defguard-client-proto = { path = "../../client-proto" } +defguard_wireguard_rs = { workspace = true } +base64.workspace = true +log.workspace = true +prost = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } + +[target.'cfg(windows)'.dependencies] +known-folders = "1.4" +windows = "0.62" +windows-acl = "0.3" +windows-service = "0.8" +windows-sys = "0.61" diff --git a/src-tauri/enterprise/service-locations/src/lib.rs b/src-tauri/enterprise/service-locations/src/lib.rs new file mode 100644 index 00000000..e1393199 --- /dev/null +++ b/src-tauri/enterprise/service-locations/src/lib.rs @@ -0,0 +1,121 @@ +use std::{collections::HashMap, fmt}; + +use defguard_client_core::database::models::{ + location::{Location, ServiceLocationMode}, + Id, +}; +use defguard_client_core::error::Error as CoreError; +use defguard_client_proto::defguard::client::v1::ServiceLocation; +use defguard_wireguard_rs::{error::WireguardInterfaceError, WGApi}; +use serde::{Deserialize, Serialize}; + +#[cfg(windows)] +pub mod windows; + +#[derive(Debug, thiserror::Error)] +pub enum ServiceLocationError { + #[error("Error occurred while initializing service location API: {0}")] + InitError(String), + #[error("Failed to load service location storage: {0}")] + LoadError(String), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + DecodeError(#[from] base64::DecodeError), + #[error(transparent)] + WireGuardError(#[from] WireguardInterfaceError), + #[error(transparent)] + AddrParseError(#[from] defguard_wireguard_rs::net::IpAddrParseError), + #[error("WireGuard interface error: {0}")] + InterfaceError(String), + #[error(transparent)] + JsonError(#[from] serde_json::Error), + #[error(transparent)] + ProtoEnumError(#[from] prost::UnknownEnumValue), + #[cfg(windows)] + #[error(transparent)] + WindowsServiceError(#[from] windows_service::Error), +} + +#[allow(dead_code)] +#[derive(Default)] +pub struct ServiceLocationManager { + // Interface name: WireGuard API instance + wgapis: HashMap, + // Instance ID: Service locations connected under that instance + connected_service_locations: HashMap>, +} + +#[allow(dead_code)] +#[derive(Serialize, Deserialize)] +pub struct ServiceLocationData { + pub service_locations: Vec, + pub instance_id: String, + pub private_key: String, +} + +#[allow(dead_code)] +pub struct SingleServiceLocationData { + pub service_location: ServiceLocation, + pub instance_id: String, + pub private_key: String, +} + +impl fmt::Debug for ServiceLocationData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServiceLocationData") + .field("service_locations", &self.service_locations) + .field("instance_id", &self.instance_id) + .field("private_key", &"***") + .finish() + } +} + +impl fmt::Debug for SingleServiceLocationData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SingleServiceLocationData") + .field("service_locations", &self.service_location) + .field("instance_id", &self.instance_id) + .field("private_key", &"***") + .finish() + } +} + +pub fn to_service_location(location: &Location) -> Result { + if !location.is_service_location() { + log::warn!( + "Location {location} is not a service location, so it can't be converted to one." + ); + return Err(CoreError::ConversionError(format!( + "Failed to convert location {location} to a service location as it's either not marked \ + as one or has MFA enabled." + ))); + } + + let mode = match location.service_location_mode { + ServiceLocationMode::Disabled => { + log::warn!( + "Location {location} has an invalid service location mode, so it can't be converted to \ + one." + ); + return Err(CoreError::ConversionError(format!( + "Location {location} has an invalid service location mode ({:?}), so it can't be \ + converted to one.", + location.service_location_mode + ))); + } + ServiceLocationMode::PreLogon => 0, + ServiceLocationMode::AlwaysOn => 1, + }; + + Ok(ServiceLocation { + name: location.name.clone(), + address: location.address.clone(), + pubkey: location.pubkey.clone(), + endpoint: location.endpoint.clone(), + allowed_ips: location.allowed_ips.clone(), + dns: location.dns.clone().unwrap_or_default(), + keepalive_interval: location.keepalive_interval.try_into().unwrap_or(0), + mode, + }) +} diff --git a/src-tauri/enterprise/service-locations/src/windows.rs b/src-tauri/enterprise/service-locations/src/windows.rs new file mode 100644 index 00000000..2def12b2 --- /dev/null +++ b/src-tauri/enterprise/service-locations/src/windows.rs @@ -0,0 +1,970 @@ +use std::{ + collections::HashMap, + ffi::OsStr, + fs::{self, create_dir_all}, + path::PathBuf, + result::Result, + str::FromStr, + sync::{Arc, RwLock}, + time::Duration, +}; + +use defguard_client_common::{dns_borrow, find_free_tcp_port, get_interface_name}; +use defguard_client_proto::defguard::client::v1::{ServiceLocation, ServiceLocationMode}; +use defguard_wireguard_rs::{ + key::Key, net::IpAddrMask, peer::Peer, InterfaceConfiguration, WGApi, WireguardInterfaceApi, +}; +use known_folders::get_known_folder_path; +use log::{debug, error, warn}; +use windows::{ + core::PSTR, + Win32::System::RemoteDesktop::{ + self, WTSQuerySessionInformationA, WTSWaitSystemEvent, WTS_CURRENT_SERVER_HANDLE, + WTS_EVENT_LOGOFF, WTS_EVENT_LOGON, WTS_SESSION_INFOA, + }, +}; +use windows_acl::acl::ACL; +use windows_sys::Win32::NetworkManagement::IpHelper::NotifyAddrChange; + +use crate::{ + ServiceLocationData, ServiceLocationError, ServiceLocationManager, SingleServiceLocationData, +}; + +const LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS: u64 = 5; +// How long to wait after a network change before attempting to connect. +// Gives DHCP time to complete and DNS to become available. +const NETWORK_STABILIZATION_DELAY: Duration = Duration::from_secs(3); +// How long to wait before restarting the network change watcher on error. +const NETWORK_CHANGE_MONITOR_RESTART_DELAY: Duration = Duration::from_secs(5); +const DEFAULT_WIREGUARD_PORT: u16 = 51820; +const DEFGUARD_DIR: &str = "Defguard"; +const SERVICE_LOCATIONS_SUBDIR: &str = "service_locations"; + +/// Watches for IP address changes on any network interface and attempts to connect to any +/// service locations that are not yet connected. This handles the case where the endpoint +/// hostname cannot be resolved at service startup because the network (e.g. Wi-Fi) is not +/// yet available. When the network comes up and an IP is assigned, this watcher fires and +/// retries the connection. +/// +/// Note: `NotifyAddrChange` also fires when WireGuard interfaces are created. This is +/// harmless because `connect_to_service_locations` skips already-connected locations. +/// +/// Runs on a dedicated OS thread because `NotifyAddrChange` is a blocking syscall. +pub(crate) fn watch_for_network_change( + service_location_manager: Arc>, +) { + loop { + // NotifyAddrChange blocks until any IP address is added or removed on any interface. + // Passing NULL for both handle and overlapped selects the synchronous (blocking) mode. + let result = unsafe { NotifyAddrChange(std::ptr::null_mut(), std::ptr::null()) }; + + if result != 0 { + error!("NotifyAddrChange failed with error code: {result}"); + std::thread::sleep(NETWORK_CHANGE_MONITOR_RESTART_DELAY); + continue; + } + + debug!( + "Network address change detected, waiting {NETWORK_STABILIZATION_DELAY:?}s for \ + network to stabilize before attempting service location connections..." + ); + std::thread::sleep(NETWORK_STABILIZATION_DELAY); + + debug!("Attempting to connect to service locations after network change"); + let connect_result = service_location_manager + .write() + .unwrap() + .connect_to_service_locations(); + match connect_result { + Ok(_) => { + debug!("Service location connect attempt after network change completed"); + } + Err(err) => { + warn!("Failed to connect to service locations after network change: {err}"); + } + } + } +} + +/// Watches for user logon/logoff events and connects/disconnects pre-logon service locations +/// accordingly. +/// +/// Runs on a dedicated OS thread because `WTSWaitSystemEvent` is a blocking syscall. +pub(crate) fn watch_for_login_logoff( + service_location_manager: Arc>, +) -> Result<(), ServiceLocationError> { + loop { + let mut event_flags: u32 = 0; + let success = unsafe { + WTSWaitSystemEvent( + Some(WTS_CURRENT_SERVER_HANDLE), + WTS_EVENT_LOGON | WTS_EVENT_LOGOFF, + &mut event_flags, + ) + }; + + match success { + Ok(_) => { + debug!("Waiting for system event returned with event_flags: 0x{event_flags:x}"); + } + Err(err) => { + error!("Failed waiting for login/logoff event: {err:?}"); + std::thread::sleep(Duration::from_secs(LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS)); + continue; + } + }; + + if event_flags & WTS_EVENT_LOGON != 0 { + debug!("Detected user logon, attempting to auto-disconnect from service locations."); + service_location_manager + .write() + .unwrap() + .disconnect_service_locations(Some(ServiceLocationMode::PreLogon))?; + } + if event_flags & WTS_EVENT_LOGOFF != 0 { + debug!("Detected user logoff, attempting to auto-connect to service locations."); + service_location_manager + .write() + .unwrap() + .connect_to_service_locations()?; + } + } +} + +fn setup_wgapi(ifname: &str) -> Result { + WGApi::new(ifname).map_err(|err| { + let msg = format!("Failed to setup WireGuard API for interface {ifname}: {err}"); + log::error!("{msg}"); + ServiceLocationError::InterfaceError(msg) + }) +} + +fn get_shared_directory() -> Result { + match get_known_folder_path(known_folders::KnownFolder::ProgramData) { + Some(mut path) => { + path.push(DEFGUARD_DIR); + path.push(SERVICE_LOCATIONS_SUBDIR); + Ok(path) + } + None => Err(ServiceLocationError::LoadError( + "Could not find ProgramData known folder".to_string(), + )), + } +} + +fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { + debug!("Setting secure ACLs on: {path}"); + + const SYSTEM_SID: &str = "S-1-5-18"; // NT AUTHORITY\SYSTEM + const ADMINISTRATORS_SID: &str = "S-1-5-32-544"; // BUILTIN\Administrators + + const FILE_ALL_ACCESS: u32 = 0x001F_01FF; + + match ACL::from_file_path(path, false) { + Ok(mut acl) => { + // Remove everything else from access + debug!("Removing all existing ACL entries for {path}"); + let all_entries = acl.all().map_err(|e| { + ServiceLocationError::LoadError(format!("Failed to get ACL entries: {e}")) + })?; + + for entry in all_entries { + if let Some(sid) = entry.sid { + if let Err(e) = acl.remove(sid.as_ptr() as *mut _, None, None) { + debug!("Note: Could not remove ACL entry (might be expected): {e}"); + } + } + } + + debug!("Cleared existing ACL entries, now adding secure entries"); + + // Add SYSTEM with full control + debug!("Adding SYSTEM with full control"); + let system_sid_result = windows_acl::helper::string_to_sid(SYSTEM_SID); + match system_sid_result { + Ok(system_sid) => { + acl.allow(system_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) + .map_err(|e| { + ServiceLocationError::LoadError(format!( + "Failed to add SYSTEM ACL: {e}" + )) + })?; + } + Err(e) => { + return Err(ServiceLocationError::LoadError(format!( + "Failed to convert SYSTEM SID: {e}" + ))); + } + } + + // Add Administrators with full control + debug!("Adding Administrators with full control"); + let admin_sid_result = windows_acl::helper::string_to_sid(ADMINISTRATORS_SID); + match admin_sid_result { + Ok(admin_sid) => { + acl.allow(admin_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) + .map_err(|e| { + ServiceLocationError::LoadError(format!( + "Failed to add Administrators ACL: {e}" + )) + })?; + } + Err(e) => { + return Err(ServiceLocationError::LoadError(format!( + "Failed to convert Administrators SID: {e}" + ))); + } + } + + debug!("Successfully set secure ACLs on {path} for SYSTEM and Administrators"); + Ok(()) + } + Err(e) => { + error!("Failed to get ACL for {path}: {e}"); + Err(ServiceLocationError::LoadError(format!( + "Failed to get ACL for {path}: {e}" + ))) + } + } +} + +fn get_instance_file_path(instance_id: &str) -> Result { + let mut path = get_shared_directory()?; + path.push(format!("{instance_id}.json")); + Ok(path) +} + +pub(crate) fn is_user_logged_in() -> bool { + debug!("Starting checking if user is logged in..."); + + unsafe { + let mut pp_sessions: *mut WTS_SESSION_INFOA = std::ptr::null_mut(); + let mut count: u32 = 0; + + debug!("Calling WTSEnumerateSessionsA..."); + let ret = RemoteDesktop::WTSEnumerateSessionsA(None, 0, 1, &mut pp_sessions, &mut count); + + match ret { + Ok(_) => { + debug!("WTSEnumerateSessionsA succeeded, found {count} sessions"); + let sessions = std::slice::from_raw_parts(pp_sessions, count as usize); + + for (index, session) in sessions.iter().enumerate() { + debug!( + "Session {index}: SessionId={}, State={:?}, WinStationName={:?}", + session.SessionId, + session.State, + std::ffi::CStr::from_ptr(session.pWinStationName.0 as *const i8) + .to_string_lossy() + ); + + if session.State == windows::Win32::System::RemoteDesktop::WTSActive { + let mut buffer = PSTR::null(); + let mut bytes_returned: u32 = 0; + + let result = WTSQuerySessionInformationA( + None, + session.SessionId, + windows::Win32::System::RemoteDesktop::WTSUserName, + &mut buffer, + &mut bytes_returned, + ); + + match result { + Ok(_) => { + if !buffer.is_null() { + let username = std::ffi::CStr::from_ptr(buffer.0 as *const i8) + .to_string_lossy() + .into_owned(); + + debug!( + "Found session {} username: {username}", + session.SessionId + ); + + windows::Win32::System::RemoteDesktop::WTSFreeMemory( + buffer.0 as *mut _, + ); + + // We found an active session with a username. + // Free the session list before returning to avoid a leak. + windows::Win32::System::RemoteDesktop::WTSFreeMemory( + pp_sessions as _, + ); + return true; + } + } + Err(err) => { + debug!( + "Failed to get username for session {}: {err:?}", + session.SessionId + ); + } + } + } + } + windows::Win32::System::RemoteDesktop::WTSFreeMemory(pp_sessions as _); + debug!("No active sessions found"); + } + Err(err) => { + error!("Failed to enumerate user sessions: {err:?}"); + debug!("WTSEnumerateSessionsA failed: {err:?}"); + } + } + } + + debug!("User is not logged in."); + false +} + +impl ServiceLocationManager { + pub fn init() -> Result { + debug!("Initializing ServiceLocationApi"); + let path = get_shared_directory()?; + + debug!("Creating directory: {path:?}"); + create_dir_all(&path)?; + + if let Some(path_str) = path.to_str() { + debug!("Setting ACLs on service locations directory"); + if let Err(e) = set_protected_acls(path_str) { + warn!("Failed to set ACLs on service locations directory: {e}. Continuing anyway."); + } + } else { + warn!("Failed to convert path to string for ACL setting"); + } + + let manager = Self { + wgapis: HashMap::new(), + connected_service_locations: HashMap::new(), + }; + + debug!("ServiceLocationApi initialized successfully"); + Ok(manager) + } + + /// Check if a specific service location is already connected + fn is_service_location_connected(&self, instance_id: &str, location_pubkey: &str) -> bool { + if let Some(locations) = self.connected_service_locations.get(instance_id) { + for location in locations { + if location.pubkey == location_pubkey { + return true; + } + } + } + false + } + + /// Add a connected service location + fn add_connected_service_location( + &mut self, + instance_id: &str, + location: &ServiceLocation, + ) -> Result<(), ServiceLocationError> { + self.connected_service_locations + .entry(instance_id.to_string()) + .or_default() + .push(location.clone()); + + debug!( + "Added connected service location for instance '{instance_id}', location '{}'", + location.name + ); + Ok(()) + } + + /// Remove connected service locations by filter (write disk-first, then memory) + fn remove_connected_service_locations( + &mut self, + filter: F, + ) -> Result<(), ServiceLocationError> + where + F: Fn(&str, &ServiceLocation) -> bool, + { + // Iterate through connected_service_locations and remove matching locations + let mut instances_to_remove = Vec::new(); + + for (instance_id, locations) in self.connected_service_locations.iter_mut() { + locations.retain(|location| !filter(instance_id, location)); + + // Mark instance for removal if it has no more locations + if locations.is_empty() { + instances_to_remove.push(instance_id.clone()); + } + } + + // Remove instances with no locations + for instance_id in instances_to_remove { + self.connected_service_locations.remove(&instance_id); + } + + debug!("Removed connected service locations matching filter"); + Ok(()) + } + + // Resets the state of the service location: + // 1. If it's an always on location, disconnects and reconnects it. + // 2. Otherwise, just disconnects it if the user is not logged in. + pub(crate) fn reset_service_location_state( + &mut self, + instance_id: &str, + location_pubkey: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Reseting the state of service location for instance_id: {instance_id}, \ + location_pubkey: {location_pubkey}" + ); + + let service_location_data = self + .load_service_location(instance_id, location_pubkey)? + .ok_or_else(|| { + ServiceLocationError::LoadError(format!( + "Service location with pubkey {} for instance {} not found", + location_pubkey, instance_id + )) + })?; + + debug!( + "Disconnecting service location for instance_id: {instance_id}, location_pubkey: \ + {location_pubkey} ({})", + service_location_data.service_location.name + ); + + self.disconnect_service_location(instance_id, location_pubkey)?; + + debug!( + "Disconnected service location for instance_id: {instance_id}, \ + location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name + ); + + debug!( + "Reconnecting service location if needed for instance_id: {instance_id}, \ + location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name + ); + + // We should reconnect only if: + // 1. It's an always on location + // 2. It's a pre-logon location and the user is not logged in + if service_location_data.service_location.mode == ServiceLocationMode::AlwaysOn as i32 + || (service_location_data.service_location.mode == ServiceLocationMode::PreLogon as i32 + && !is_user_logged_in()) + { + debug!( + "Reconnecting service location for instance_id: {instance_id}, location_pubkey: \ + {location_pubkey} ({})", + service_location_data.service_location.name + ); + self.connect_to_service_location(&service_location_data)?; + } + + debug!("Service location state reset completed."); + + Ok(()) + } + + pub(crate) fn disconnect_service_locations_by_instance( + &mut self, + instance_id: &str, + ) -> Result<(), ServiceLocationError> { + debug!("Disconnecting all service locations for instance_id: {instance_id}"); + + if let Some(locations) = self.connected_service_locations.get(instance_id) { + // Collect locations to disconnect to avoid borrowing issues + let locations_to_disconnect = locations.to_vec(); + + for location in locations_to_disconnect { + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { + if let Err(err) = wgapi.remove_interface() { + error!("Failed to remove interface {ifname}: {err}"); + } else { + debug!("Interface {ifname} removed successfully"); + } + debug!( + "Removing connected service location for instance_id: {instance_id}, \ + location_pubkey: {}", + location.pubkey + ); + debug!( + "Disconnected service location for instance_id: {instance_id}, \ + location_pubkey: {}", + location.pubkey + ); + } else { + error!("Failed to find WireGuard API for interface {ifname}"); + } + } + + self.connected_service_locations.remove(instance_id); + } else { + debug!( + "No connected service locations found for instance_id: {instance_id}. Skipping disconnect" + ); + return Ok(()); + } + + debug!("Disconnected all service locations for instance_id: {instance_id}"); + + Ok(()) + } + + pub(crate) fn disconnect_service_location( + &mut self, + instance_id: &str, + location_pubkey: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Disconnecting service location for instance_id: {instance_id}, location_pubkey: \ + {location_pubkey}" + ); + + if let Some(locations) = self.connected_service_locations.get_mut(instance_id) { + if let Some(pos) = locations + .iter() + .position(|loc| loc.pubkey == location_pubkey) + { + let location = locations.remove(pos); + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { + if let Err(err) = wgapi.remove_interface() { + error!("Failed to remove interface {ifname}: {err}"); + } else { + debug!("Interface {ifname} removed successfully."); + } + } else { + error!("Failed to find WireGuard API for interface {ifname}. "); + } + } else { + debug!( + "Service location with pubkey {location_pubkey} for instance {instance_id} is \ + not connected, skipping disconnect" + ); + return Ok(()); + } + } else { + debug!( + "No connected service locations found for instance_id: {instance_id}, skipping \ + disconnect" + ); + return Ok(()); + } + + debug!( + "Disconnected service location for instance_id: {instance_id}, location_pubkey: \ + {location_pubkey}" + ); + + Ok(()) + } + + /// Helper function to setup a WireGuard interface for a service location + fn setup_service_location_interface( + &mut self, + location: &ServiceLocation, + private_key: &str, + ) -> Result<(), ServiceLocationError> { + let peer_key = Key::from_str(&location.pubkey)?; + + let mut peer = Peer::new(peer_key.clone()); + peer.set_endpoint(&location.endpoint)?; + + peer.persistent_keepalive_interval = location.keepalive_interval.try_into().ok(); + + let allowed_ips = location + .allowed_ips + .split(',') + .map(str::to_string) + .collect::>(); + + for allowed_ip in &allowed_ips { + match IpAddrMask::from_str(allowed_ip) { + Ok(addr) => { + peer.allowed_ips.push(addr); + } + Err(err) => { + error!( + "Error parsing IP address {allowed_ip} while setting up interface for \ + location {location:?}, error details: {err}" + ); + } + } + } + + let mut addresses = Vec::new(); + + for address in location.address.split(',') { + addresses.push(IpAddrMask::from_str(address.trim())?); + } + + let config = InterfaceConfiguration { + name: location.name.clone(), + prvkey: private_key.to_string(), + addresses, + port: find_free_tcp_port().unwrap_or(DEFAULT_WIREGUARD_PORT), + peers: vec![peer.clone()], + mtu: None, + fwmark: None, // TODO: add + }; + + let ifname = location.name.clone(); + let ifname = get_interface_name(&ifname); + let mut wgapi = match setup_wgapi(&ifname) { + Ok(api) => api, + Err(err) => { + let msg = format!("Failed to setup WireGuard API for interface {ifname}: {err:?}"); + debug!("{msg}"); + return Err(ServiceLocationError::InterfaceError(msg)); + } + }; + + wgapi.create_interface()?; + + // Extract DNS configuration if available + let dns_config = Some(location.dns.clone()); + let (dns, search_domains) = dns_borrow(&dns_config); + debug!( + "Configuring interface {ifname} with DNS: {dns:?} and search domains: \ + {search_domains:?}", + ); + debug!("Interface Configuration: {config:?}"); + + wgapi.configure_interface(&config)?; + wgapi.configure_dns(&dns, &search_domains)?; + + self.wgapis.insert(ifname.clone(), wgapi); + + debug!("Interface {ifname} configured successfully."); + Ok(()) + } + + pub(crate) fn connect_to_service_location( + &mut self, + location_data: &SingleServiceLocationData, + ) -> Result<(), ServiceLocationError> { + let instance_id = &location_data.instance_id; + let location_pubkey = &location_data.service_location.pubkey; + debug!( + "Connecting to service location for instance_id: {instance_id}, location_pubkey: \ + {location_pubkey}" + ); + + // Check if already connected to this service location + if self.is_service_location_connected(instance_id, location_pubkey) { + debug!( + "Service location with pubkey {location_pubkey} for instance {instance_id} is \ + already connected, skipping" + ); + return Ok(()); + } + + let location_data = self + .load_service_location(instance_id, location_pubkey)? + .ok_or_else(|| { + ServiceLocationError::LoadError(format!( + "Service location with pubkey {location_pubkey} for instance {instance_id} not \ + found", + )) + })?; + + self.setup_service_location_interface( + &location_data.service_location, + &location_data.private_key, + )?; + self.add_connected_service_location( + &location_data.instance_id, + &location_data.service_location, + )?; + let ifname = get_interface_name(&location_data.service_location.name); + debug!("Successfully connected to service location '{ifname}'"); + + Ok(()) + } + + pub(crate) fn disconnect_service_locations( + &mut self, + mode: Option, + ) -> Result<(), ServiceLocationError> { + debug!("Disconnecting service locations with mode: {mode:?}"); + + for (instance, locations) in &self.connected_service_locations { + for location in locations { + debug!( + "Found connected service location for instance_id: {instance}, \ + location_pubkey: {}", + location.pubkey + ); + if let Some(m) = mode { + let location_mode: ServiceLocationMode = location.mode.try_into()?; + if location_mode != m { + debug!( + "Skipping interface {} due to the service location mode doesn't match the \ + requested mode (expected {m:?}, found {:?})", + location.name, location.mode + ); + continue; + } + } + + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { + if let Err(err) = wgapi.remove_interface() { + error!("Failed to remove interface {ifname}: {err}"); + } else { + debug!("Interface {ifname} removed successfully."); + } + } else { + error!("Failed to find WireGuard API for interface {ifname}"); + } + } + } + + self.remove_connected_service_locations(|_, location| { + if let Some(m) = mode { + let location_mode: ServiceLocationMode = location + .mode + .try_into() + .unwrap_or(ServiceLocationMode::AlwaysOn); + location_mode == m + } else { + true + } + })?; + + debug!("Service locations disconnected."); + + Ok(()) + } + + /// Attempts to connect to all service locations that are not already connected. + /// + /// Returns `Ok(true)` if every location is now connected (either it was already connected or + /// it was successfully connected during this call), and `Ok(false)` if at least one location + /// failed to connect (indicating that a retry may be worthwhile). + pub(crate) fn connect_to_service_locations(&mut self) -> Result { + debug!("Attempting to auto-connect to VPN..."); + + let data = self.load_service_locations()?; + debug!("Loaded {} instance(s) from ServiceLocationApi", data.len()); + + let mut all_connected = true; + + for instance_data in data { + debug!( + "Found service locations for instance ID: {}", + instance_data.instance_id + ); + debug!( + "Instance has {} service location(s)", + instance_data.service_locations.len() + ); + for location in instance_data.service_locations { + debug!("Service Location: {location:?}"); + + if location.mode == ServiceLocationMode::PreLogon as i32 { + if is_user_logged_in() { + debug!( + "Skipping pre-logon service location '{}' because user is logged in", + location.name + ); + continue; + } + debug!( + "Proceeding to connect pre-logon service location '{}' because no user \ + is logged in", + location.name + ); + } + + if self.is_service_location_connected(&instance_data.instance_id, &location.pubkey) + { + debug!( + "Skipping service location '{}' because it's already connected", + location.name + ); + continue; + } + + if let Err(err) = + self.setup_service_location_interface(&location, &instance_data.private_key) + { + warn!( + "Failed to setup service location interface for '{}': {err:?}", + location.name + ); + all_connected = false; + continue; + } + + if let Err(err) = + self.add_connected_service_location(&instance_data.instance_id, &location) + { + debug!( + "Failed to persist connected service location after auto-connect: {err:?}" + ); + } + + debug!( + "Successfully connected to service location '{}'", + location.name + ); + } + } + + debug!("Auto-connect attempt completed"); + + Ok(all_connected) + } + + pub fn save_service_locations( + &self, + service_locations: &[ServiceLocation], + instance_id: &str, + private_key: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Received a request to save {} service location(s) for instance {instance_id}", + service_locations.len(), + ); + + debug!("Service locations to save: {service_locations:?}"); + + create_dir_all(get_shared_directory()?)?; + + let instance_file_path = get_instance_file_path(instance_id)?; + + let service_location_data = ServiceLocationData { + service_locations: service_locations.to_vec(), + instance_id: instance_id.to_string(), + private_key: private_key.to_string(), + }; + + let json = serde_json::to_string_pretty(&service_location_data)?; + + debug!( + "Writing service location data to file: {}", + instance_file_path.display() + ); + + fs::write(&instance_file_path, &json)?; + + if let Some(file_path_str) = instance_file_path.to_str() { + debug!("Setting ACLs on service location file: {file_path_str}"); + if let Err(err) = set_protected_acls(file_path_str) { + warn!( + "Failed to set ACLs on service location file {file_path_str}: {err}. \ + File saved but may have insecure permissions." + ); + } else { + debug!("Successfully set ACLs on service location file"); + } + } else { + warn!("Failed to convert file path to string for ACL setting"); + } + + debug!( + "Service locations saved successfully for instance {instance_id} to {}", + instance_file_path.display() + ); + Ok(()) + } + + fn load_service_locations(&self) -> Result, ServiceLocationError> { + let base_dir = get_shared_directory()?; + let mut all_locations_data = Vec::new(); + + if base_dir.exists() { + for entry in fs::read_dir(base_dir)? { + let entry = entry?; + let file_path = entry.path(); + + if file_path.is_file() && file_path.extension() == Some(OsStr::new("json")) { + match fs::read_to_string(&file_path) { + Ok(data) => match serde_json::from_str::(&data) { + Ok(locations_data) => { + all_locations_data.push(locations_data); + } + Err(err) => { + error!( + "Failed to parse service locations from file {}: {err}", + file_path.display() + ); + } + }, + Err(err) => { + error!( + "Failed to read service locations file {}: {err}", + file_path.display() + ); + } + } + } + } + } + + debug!( + "Loaded service locations data for {} instances", + all_locations_data.len() + ); + Ok(all_locations_data) + } + + fn load_service_location( + &self, + instance_id: &str, + location_pubkey: &str, + ) -> Result, ServiceLocationError> { + debug!("Loading service location for instance {instance_id} and pubkey {location_pubkey}"); + + let instance_file_path = get_instance_file_path(instance_id)?; + + if instance_file_path.exists() { + let data = fs::read_to_string(&instance_file_path)?; + let service_location_data = serde_json::from_str::(&data)?; + + for location in service_location_data.service_locations { + if location.pubkey == location_pubkey { + debug!( + "Successfully loaded service location for instance {instance_id} and \ + pubkey {location_pubkey}" + ); + return Ok(Some(SingleServiceLocationData { + service_location: location, + instance_id: service_location_data.instance_id, + private_key: service_location_data.private_key, + })); + } + } + + debug!( + "No service location found for instance {instance_id} with pubkey {location_pubkey}" + ); + Ok(None) + } else { + debug!("No service location file found for instance {instance_id}"); + Ok(None) + } + } + + pub(crate) fn delete_all_service_locations_for_instance( + &self, + instance_id: &str, + ) -> Result<(), ServiceLocationError> { + debug!("Deleting all service locations for instance {instance_id}"); + + let instance_file_path = get_instance_file_path(instance_id)?; + + if instance_file_path.exists() { + fs::remove_file(&instance_file_path)?; + debug!("Successfully deleted all service locations for instance {instance_id}"); + } else { + debug!("No service location file found for instance {instance_id}"); + } + + Ok(()) + } +} diff --git a/src-tauri/src/active_connections.rs b/src-tauri/src/active_connections.rs index 970a31bd..c42ccb73 100644 --- a/src-tauri/src/active_connections.rs +++ b/src-tauri/src/active_connections.rs @@ -1,97 +1,4 @@ -use std::{collections::HashSet, sync::LazyLock}; - -use tokio::sync::Mutex; - -use crate::{ - database::{ - models::{connection::ActiveConnection, instance::Instance, location::Location, Id}, - DB_POOL, - }, - error::Error, - utils::disconnect_interface, - ConnectionType, +pub use defguard_client_core::connection::active_connections::{ + active_connections, close_all_connections, find_connection, get_connection_id_by_type, + ACTIVE_CONNECTIONS, }; - -pub(crate) static ACTIVE_CONNECTIONS: LazyLock>> = - LazyLock::new(|| Mutex::new(Vec::new())); - -pub(crate) async fn get_connection_id_by_type(connection_type: ConnectionType) -> Vec { - let active_connections = ACTIVE_CONNECTIONS.lock().await; - - let connection_ids = active_connections - .iter() - .filter_map(|con| { - if con.connection_type == connection_type { - Some(con.location_id) - } else { - None - } - }) - .collect(); - - connection_ids -} - -pub async fn close_all_connections() -> Result<(), Error> { - debug!("Closing all active connections"); - let active_connections = ACTIVE_CONNECTIONS.lock().await; - let active_connections_count = active_connections.len(); - debug!("Found {active_connections_count} active connections"); - for connection in active_connections.iter() { - debug!( - "Found active connection with location {}", - connection.location_id - ); - trace!("Connection: {connection:#?}"); - debug!("Removing interface {}", connection.interface_name); - disconnect_interface(connection).await?; - } - if active_connections_count > 0 { - info!("All active connections ({active_connections_count}) have been closed."); - } else { - debug!("There were no active connections to close, nothing to do."); - } - Ok(()) -} - -pub(crate) async fn find_connection( - id: Id, - connection_type: ConnectionType, -) -> Option { - let connections = ACTIVE_CONNECTIONS.lock().await; - trace!( - "Checking for active connection with ID {id}, type {connection_type} in active connections." - ); - - if let Some(connection) = connections - .iter() - .find(|conn| conn.location_id == id && conn.connection_type == connection_type) - { - // 'connection' now contains the first element with the specified id and connection_type - trace!("Found connection: {connection:?}"); - Some(connection.to_owned()) - } else { - debug!( - "Couldn't find connection with ID {id}, type: {connection_type} in active connections." - ); - None - } -} - -/// Returns active connections for a given instance. -pub(crate) async fn active_connections( - instance: &Instance, -) -> Result, Error> { - let locations: HashSet = Location::find_by_instance_id(&*DB_POOL, instance.id, false) - .await? - .iter() - .map(|location| location.id) - .collect(); - Ok(ACTIVE_CONNECTIONS - .lock() - .await - .iter() - .filter(|connection| locations.contains(&connection.location_id)) - .cloned() - .collect()) -} diff --git a/src-tauri/src/apple.rs b/src-tauri/src/apple.rs index b388e6f2..53dbc486 100644 --- a/src-tauri/src/apple.rs +++ b/src-tauri/src/apple.rs @@ -15,7 +15,7 @@ use std::{ }; use block2::RcBlock; -use common::dns_owned; +use defguard_client_common::dns_owned; use defguard_wireguard_rs::{key::Key, net::IpAddrMask, peer::Peer}; use objc2::{ rc::Retained, diff --git a/src-tauri/src/bin/defguard-client.rs b/src-tauri/src/bin/defguard-client.rs index 8925cd6f..76911e78 100644 --- a/src-tauri/src/bin/defguard-client.rs +++ b/src-tauri/src/bin/defguard-client.rs @@ -305,7 +305,11 @@ fn main() { } // Prepare `AppConfig`. - let config = AppConfig::new(app_handle); + let config_dir = app_handle + .path() + .app_data_dir() + .expect("Failed to access app data"); + let config = AppConfig::new(&config_dir); // Setup logging. diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index 9f5d24b8..31003f39 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -1,13 +1,8 @@ use core::fmt; -use std::{ - collections::{HashMap, HashSet}, - env, - str::FromStr, -}; +use std::{collections::HashMap, env, str::FromStr}; -use chrono::{DateTime, Duration, NaiveDateTime, Utc}; +use chrono::{DateTime, Duration, Utc}; use serde::{Deserialize, Serialize}; -use sqlx::{Sqlite, Transaction}; use struct_patch::Patch; use tauri::{AppHandle, Emitter, Manager, State}; @@ -30,22 +25,25 @@ use crate::{ DB_POOL, }, enterprise::{ - self, periodic::config::poll_instance, posture::authorize_posture_session, + self, + periodic::config::{do_update_instance, poll_instance}, + posture::authorize_posture_session, provisioning::ProvisioningConfig, }, error::Error, events::EventKey, + into_location, log_watcher::{ global_log_watcher::{spawn_global_log_watcher_task, stop_global_log_watcher_task}, service_log_watcher::stop_log_watcher_task, }, proto::defguard::client_types::DeviceConfigResponse, + proxy::construct_platform_header, service::proto::defguard::enterprise::posture::v2::DevicePostureData, tray::{configure_tray_icon, reload_tray_menu}, utils::{ - construct_platform_header, disconnect_interface, get_location_interface_details, - get_tunnel_interface_details, get_tunnel_or_location_name, handle_connection_for_location, - handle_connection_for_tunnel, + disconnect_interface, get_location_interface_details, get_tunnel_interface_details, + get_tunnel_or_location_name, handle_connection_for_location, handle_connection_for_tunnel, }, wg_config::parse_wireguard_config, CommonConnection, CommonConnectionInfo, CommonLocationStats, ConnectionType, @@ -172,7 +170,9 @@ pub async fn disconnect( "Emitting the event informing the frontend about the disconnection from \ {connection_type} {name}({location_id})" ); - handle.emit(EventKey::ConnectionChanged.into(), ())?; + handle + .emit(EventKey::ConnectionChanged.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; debug!("Event emitted successfully"); stop_log_watcher_task(&handle, &connection.interface_name)?; reload_tray_menu(&handle).await; @@ -282,7 +282,9 @@ pub async fn disconnect_locations(location_ids: Vec, handle: AppHandle) -> R } if any_disconnected { - handle.emit(EventKey::ConnectionChanged.into(), ())?; + handle + .emit(EventKey::ConnectionChanged.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; reload_tray_menu(&handle).await; configure_tray_icon(&handle).await?; } @@ -308,7 +310,9 @@ async fn maybe_update_instance_config(location_id: Id, handle: &AppHandle) -> Re }; poll_instance(&mut transaction, &mut instance, handle).await?; transaction.commit().await?; - handle.emit(EventKey::InstanceUpdate.into(), ())?; + handle + .emit(EventKey::InstanceUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; Ok(()) } @@ -380,7 +384,7 @@ pub async fn save_device_config( keys.pubkey, instance.name, instance.id ); for dev_config in response.configs { - let new_location = dev_config.into_location(instance.id); + let new_location = into_location(dev_config, instance.id); debug!( "Saving location {} for instance {}({})", new_location.name, instance.name, instance.id @@ -397,7 +401,9 @@ pub async fn save_device_config( let locations = push_service_locations(&instance, keys).await?; - handle.emit(EventKey::InstanceUpdate.into(), ())?; + handle + .emit(EventKey::InstanceUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; let res = SaveDeviceConfigResponse { locations, instance, @@ -433,7 +439,9 @@ async fn push_service_locations( "Adding service location {}({}) for instance {}({}) to be saved to the daemon.", saved_location.name, saved_location.id, instance.name, instance.id, ); - service_locations.push(saved_location.to_service_location()?); + service_locations.push(crate::enterprise::service_locations::to_service_location( + saved_location, + )?); } } @@ -653,7 +661,9 @@ pub async fn update_instance( do_update_instance(&mut transaction, &mut instance, response).await?; transaction.commit().await?; - app_handle.emit(EventKey::InstanceUpdate.into(), ())?; + app_handle + .emit(EventKey::InstanceUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; reload_tray_menu(&app_handle).await; Ok(()) } else { @@ -662,237 +672,6 @@ pub async fn update_instance( } } -/// Returns true if configuration in instance_info differs from current configuration -pub(crate) async fn locations_changed( - transaction: &mut Transaction<'_, Sqlite>, - instance: &Instance, - device_config: &DeviceConfigResponse, -) -> Result { - let db_locations = Location::find_by_instance_id(transaction.as_mut(), instance.id, true) - .await? - .into_iter() - .map(|location| { - let mut new_location = Location::::from(location); - // Ignore `route_all_traffic` flag as Defguard core does not have it. - new_location.route_all_traffic = false; - // Canonicalize mfa_method so a user-set value doesn't falsely trigger a - // config-change detection when the mode hasn't actually changed. - new_location.mfa_method = infer_mfa_method(new_location.location_mfa_mode, None); - new_location - }) - .collect::>(); - let core_locations: HashSet = device_config - .configs - .iter() - .map(|config| config.clone().into_location(instance.id)) - .collect::>(); - - Ok(db_locations != core_locations) -} - -pub(crate) async fn do_update_instance( - transaction: &mut Transaction<'_, Sqlite>, - instance: &mut Instance, - response: DeviceConfigResponse, -) -> Result<(), Error> { - // update instance - debug!("Updating instance {instance}"); - let locations_changed = locations_changed(transaction, instance, &response).await?; - let instance_info = response - .instance - .expect("Missing instance info in device config response"); - instance.name = instance_info.name; - instance.url = instance_info.url; - instance.proxy_url = instance_info.proxy_url; - instance.username = instance_info.username; - // Make sure to update the locations too if we are disabling all traffic - let policy = instance_info.client_traffic_policy.into(); - if instance.client_traffic_policy != policy && policy == ClientTrafficPolicy::DisableAllTraffic - { - debug!("Disabling all traffic for all locations of instance {instance}"); - Location::disable_all_traffic_for_all(transaction.as_mut(), instance.id).await?; - debug!("Disabled all traffic for all locations of instance {instance}"); - } - instance.client_traffic_policy = instance_info.client_traffic_policy.into(); - instance.openid_display_name = instance_info.openid_display_name; - instance.uuid = instance_info.id; - // Token may be empty if it was not issued - // This happens during polling, as core doesn't issue a new token for polling request - if response.token.is_some() { - instance.token = response.token; - debug!("Set polling token for instance {}", instance.name); - } else { - debug!( - "No polling token received for instance {}, not updating", - instance.name - ); - } - instance.save(transaction.as_mut()).await?; - debug!( - "A new base configuration has been applied to instance {instance}, even if nothing changed" - ); - - let mut service_locations = Vec::new(); - - // check if locations have changed - if locations_changed { - // process locations received in response - debug!( - "Updating locations for instance {}({}).", - instance.name, instance.id - ); - // Fetch existing locations for a given instance. - let mut current_locations = - Location::find_by_instance_id(transaction.as_mut(), instance.id, true).await?; - for dev_config in response.configs { - // parse device config - let new_location = dev_config.into_location(instance.id); - - // check if location is already present in current locations - let saved_location = if let Some(position) = current_locations - .iter() - .position(|loc| loc.network_id == new_location.network_id) - { - // remove from list of existing locations - let mut current_location = current_locations.remove(position); - debug!( - "Updating existing location {}({}) for instance {}({}).", - current_location.name, current_location.id, instance.name, instance.id, - ); - // update existing location - current_location.name = new_location.name; - current_location.address = new_location.address; - current_location.pubkey = new_location.pubkey; - current_location.endpoint = new_location.endpoint; - current_location.allowed_ips = new_location.allowed_ips; - current_location.keepalive_interval = new_location.keepalive_interval; - current_location.dns = new_location.dns; - current_location.location_mfa_mode = new_location.location_mfa_mode; - current_location.service_location_mode = new_location.service_location_mode; - // Correct mfa_method to remain consistent with the (possibly updated) mfa_mode. - current_location.mfa_method = infer_mfa_method( - current_location.location_mfa_mode, - current_location.mfa_method, - ); - current_location.posture_check_required = new_location.posture_check_required; - current_location.save(transaction.as_mut()).await?; - info!("Location {current_location} configuration updated for instance {instance}"); - current_location - } else { - // create new location - debug!("Creating new location {new_location} for instance instance {instance}"); - let new_location = new_location.save(transaction.as_mut()).await?; - info!("New location {new_location} created for instance {instance}"); - new_location - }; - - if saved_location.is_service_location() { - debug!( - "Adding service location {}({}) for instance {}({}) to be saved to the daemon.", - saved_location.name, saved_location.id, instance.name, instance.id, - ); - service_locations.push(saved_location.to_service_location()?); - } - } - - // remove locations which were present in current locations - // but no longer found in core response - debug!("Removing locations for instance {instance}"); - for removed_location in current_locations { - removed_location.delete(transaction.as_mut()).await?; - info!( - "Removed location {removed_location} for instance {instance} during instance update" - ); - } - debug!("Finished updating locations for instance {instance}"); - } else { - info!("Locations for instance {instance} didn't change. Not updating them."); - } - - if service_locations.is_empty() { - debug!( - "No service locations for instance {}({}), removing all existing service locations connections if there are any.", - instance.name, instance.id - ); - - #[cfg(not(target_os = "macos"))] - { - let delete_request = DeleteServiceLocationsRequest { - instance_id: instance.uuid.clone(), - }; - DAEMON_CLIENT - .clone() - .delete_service_locations(delete_request) - .await - .map_err(|err| { - error!( - "Error while deleting service locations from the daemon for instance {}({}): {err}", - instance.name, instance.id, - ); - Error::InternalError(err.to_string()) - })?; - debug!( - "Successfully removed all service locations from daemon for instance {}({})", - instance.name, instance.id - ); - } - } else { - debug!( - "Processing {} service location(s) for instance {}({})", - service_locations.len(), - instance.name, - instance.id - ); - - #[cfg(not(target_os = "macos"))] - { - let private_key = WireguardKeys::find_by_instance_id(transaction.as_mut(), instance.id) - .await? - .ok_or(Error::NotFound)? - .prvkey; - - let save_request = SaveServiceLocationsRequest { - service_locations: service_locations.clone(), - instance_id: instance.uuid.clone(), - private_key, - }; - - debug!( - "Sending request to daemon to save {} service location(s) for instance {}({})", - save_request.service_locations.len(), - instance.name, - instance.id - ); - - DAEMON_CLIENT - .clone() - .save_service_locations(save_request) - .await - .map_err(|err| { - error!( - "Error while saving service locations to the daemon for instance {}({}): {err}", - instance.name, instance.id, - ); - Error::InternalError(err.to_string()) - })?; - - info!( - "Successfully saved {} service location(s) to daemon for instance {}({})", - service_locations.len(), - instance.name, - instance.id - ); - - debug!( - "Completed processing all service locations for instance {}({})", - instance.name, instance.id - ); - } - } - - Ok(()) -} - /// If `datetime` is Some, parses the date string, otherwise returns `DateTime` one hour ago. pub(crate) fn parse_timestamp(from: Option) -> Result, Error> { Ok(match from { @@ -901,35 +680,6 @@ pub(crate) fn parse_timestamp(from: Option) -> Result, Err }) } -pub(crate) enum DateTimeAggregation { - Hour, - Second, -} - -impl DateTimeAggregation { - /// Returns database format string for a given aggregation variant. - #[must_use] - pub(crate) fn fstring(&self) -> &'static str { - match self { - Self::Hour => "%Y-%m-%d %H:00:00", - Self::Second => "%Y-%m-%d %H:%M:%S", - } - } -} - -pub(crate) fn get_aggregation(from: NaiveDateTime) -> Result { - // Use hourly aggregation for longer periods - let aggregation = match Utc::now().naive_utc() - from { - duration if duration >= Duration::hours(8) => Ok(DateTimeAggregation::Hour), - duration if duration < Duration::zero() => Err(Error::InternalError(format!( - "Negative duration between dates: now ({}) and {from}", - Utc::now().naive_utc(), - ))), - _ => Ok(DateTimeAggregation::Second), - }?; - Ok(aggregation) -} - #[tauri::command(async)] pub async fn location_stats( location_id: Id, @@ -938,7 +688,7 @@ pub async fn location_stats( ) -> Result>, Error> { trace!("Location stats command received"); let from = parse_timestamp(from)?.naive_utc(); - let aggregation = get_aggregation(from)?; + let aggregation = crate::get_aggregation(from)?; let stats = match connection_type { ConnectionType::Location => { LocationStats::all_by_location_id(&*DB_POOL, location_id, &from, &aggregation, None) @@ -1091,7 +841,9 @@ pub async fn update_location_routing( location.route_all_traffic = route_all_traffic; location.save(&*DB_POOL).await?; debug!("Location routing updated for location {name}(ID: {location_id})"); - handle.emit(EventKey::LocationUpdate.into(), ())?; + handle + .emit(EventKey::LocationUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; Ok(()) } else { error!( @@ -1105,7 +857,9 @@ pub async fn update_location_routing( tunnel.route_all_traffic = route_all_traffic; tunnel.save(&*DB_POOL).await?; info!("Tunnel routing updated for tunnel {location_id}"); - handle.emit(EventKey::LocationUpdate.into(), ())?; + handle + .emit(EventKey::LocationUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; Ok(()) } else { error!("Couldn't update tunnel routing: tunnel with id {location_id} not found."); @@ -1134,7 +888,9 @@ pub async fn set_location_mfa_method( "MFA method updated for location {}(ID: {location_id})", location.name, ); - handle.emit(EventKey::LocationUpdate.into(), ())?; + handle + .emit(EventKey::LocationUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; Ok(()) } else { error!("Location with ID {location_id} not found, cannot set MFA method"); @@ -1183,7 +939,9 @@ pub async fn delete_instance(instance_id: Id, handle: AppHandle) -> Result<(), E configure_tray_icon(&handle).await?; - handle.emit(EventKey::InstanceUpdate.into(), ())?; + handle + .emit(EventKey::InstanceUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; info!("Successfully deleted instance {instance}."); Ok(()) } @@ -1258,7 +1016,9 @@ pub async fn delete_instance(instance_id: Id, handle: AppHandle) -> Result<(), E configure_tray_icon(&handle).await?; - handle.emit(EventKey::InstanceUpdate.into(), ())?; + handle + .emit(EventKey::InstanceUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; info!("Successfully deleted instance {instance}."); Ok(()) } @@ -1279,7 +1039,9 @@ pub async fn update_tunnel(mut tunnel: Tunnel, handle: AppHandle) -> Result< debug!("Received tunnel configuration to update: {tunnel}"); tunnel.save(&*DB_POOL).await?; info!("The tunnel {tunnel} configuration has been updated."); - handle.emit(EventKey::LocationUpdate.into(), ())?; + handle + .emit(EventKey::LocationUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; Ok(()) } @@ -1288,7 +1050,9 @@ pub async fn save_tunnel(tunnel: Tunnel, handle: AppHandle) -> Result<(), debug!("Received tunnel configuration to save: {tunnel}"); let tunnel = tunnel.save(&*DB_POOL).await?; info!("The tunnel {tunnel} configuration has been saved."); - handle.emit(EventKey::LocationUpdate.into(), ())?; + handle + .emit(EventKey::LocationUpdate.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; Ok(()) } @@ -1530,7 +1294,11 @@ pub async fn command_set_app_config( let res = { let mut app_config = app_state.app_config.lock().unwrap(); app_config.apply(config_patch); - app_config.save(&app_handle); + let config_dir = app_handle + .path() + .app_data_dir() + .expect("Failed to access app data"); + app_config.save(&config_dir); app_config.clone() }; info!("Config changed successfully"); diff --git a/src-tauri/src/enterprise/inspector/mod.rs b/src-tauri/src/enterprise/inspector/mod.rs deleted file mode 100644 index 2a7484ce..00000000 --- a/src-tauri/src/enterprise/inspector/mod.rs +++ /dev/null @@ -1,191 +0,0 @@ -#[cfg(target_os = "linux")] -pub(crate) mod linux; -#[cfg(target_os = "macos")] -pub(crate) mod macos; -#[cfg(test)] -mod tests; -#[cfg(windows)] -pub(crate) mod windows; - -use std::{env::consts::OS, error::Error, fmt}; - -use sysinfo::System; - -use crate::{ - service::proto::defguard::enterprise::posture::v2::{ - bool_check, int32_check, string_check, BoolCheck, DevicePostureData, Int32Check, - StringCheck, UnavailableReason, - }, - VERSION, -}; - -impl fmt::Display for UnavailableReason { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Unspecified => f.write_str("unspecified"), - Self::DetectionFailed => f.write_str("detection failed"), - Self::NotApplicable => f.write_str("not applicable on this platform"), - Self::InsufficientPermissions => f.write_str("insufficient permissions"), - } - } -} - -impl Error for UnavailableReason {} - -/// Returns the operating system name. -fn os_name() -> Result { - System::name().ok_or(UnavailableReason::DetectionFailed) -} - -/// Returns the operating system version. -fn os_version() -> Result { - #[cfg(windows)] - { - // Windows can report versions like "11 (26200)"; core expects a parseable major. - System::os_version() - .and_then(|version| version.split_whitespace().next().map(ToString::to_string)) - .ok_or(UnavailableReason::DetectionFailed) - } - - #[cfg(not(windows))] - { - System::os_version().ok_or(UnavailableReason::DetectionFailed) - } -} - -/// Returns the Linux kernel version. -fn linux_kernel_version() -> Result { - #[cfg(target_os = "linux")] - { - System::kernel_version().ok_or(UnavailableReason::DetectionFailed) - } - - #[cfg(not(target_os = "linux"))] - { - Err(UnavailableReason::NotApplicable) - } -} - -/// Returns the disk encryption status, preferably for the system volume. -fn disk_encryption_status() -> Result { - #[cfg(target_os = "macos")] - { - macos::disk_encryption_status() - } - - #[cfg(windows)] - { - windows::disk_encryption_status() - } - - #[cfg(target_os = "linux")] - { - linux::disk_encryption_status() - } -} - -/// Returns the antivirus status. -fn anti_virus_status() -> Result { - #[cfg(windows)] - { - windows::anti_virus_status() - } - - #[cfg(not(windows))] - { - Err(UnavailableReason::NotApplicable) - } -} - -/// Checks whether the computer is part of a domain. -fn part_of_domain() -> Result { - #[cfg(windows)] - { - windows::part_of_domain() - } - - #[cfg(not(windows))] - { - Err(UnavailableReason::NotApplicable) - } -} - -/// Returns the device integrity status. -fn device_integrity() -> Result { - #[cfg(target_os = "macos")] - { - macos::system_integrity_status() - } - - #[cfg(not(target_os = "macos"))] - Err(UnavailableReason::NotApplicable) -} - -/// Returns the number of days since the last installed Windows security update. -fn security_update_age_days() -> Result { - #[cfg(windows)] - { - windows::security_update_age_days() - } - - #[cfg(not(windows))] - { - Err(UnavailableReason::NotApplicable) - } -} - -/// Convert `Result` to `BoolCheck`. -impl From> for BoolCheck { - fn from(value: Result) -> Self { - Self { - result: Some(match value { - Ok(inner) => bool_check::Result::Value(inner), - Err(err) => bool_check::Result::Unavailable(err as i32), - }), - } - } -} - -/// Convert `Result` to `Int32Check`. -impl From> for Int32Check { - fn from(value: Result) -> Self { - Self { - result: Some(match value { - Ok(inner) => int32_check::Result::Value(inner), - Err(err) => int32_check::Result::Unavailable(err as i32), - }), - } - } -} - -/// Convert `Result` to `StringCheck`. -impl From> for StringCheck { - fn from(value: Result) -> Self { - Self { - result: Some(match value { - Ok(inner) => string_check::Result::Value(inner), - Err(err) => string_check::Result::Unavailable(err as i32), - }), - } - } -} - -#[allow(unused)] -impl DevicePostureData { - /// Performs system inspection and returns the results. - #[must_use] - pub fn new() -> Self { - Self { - defguard_client_version: VERSION.to_owned(), - os_type: OS.to_string(), - os_name: Some(StringCheck::from(os_name())), - os_version: Some(StringCheck::from(os_version())), - disk_encryption: Some(BoolCheck::from(disk_encryption_status())), - antivirus_present: Some(BoolCheck::from(anti_virus_status())), - windows_ad_domain_joined: Some(BoolCheck::from(part_of_domain())), - windows_security_update_age_days: Some(Int32Check::from(security_update_age_days())), - linux_kernel_version: Some(StringCheck::from(linux_kernel_version())), - device_integrity: Some(BoolCheck::from(device_integrity())), - } - } -} diff --git a/src-tauri/src/enterprise/mod.rs b/src-tauri/src/enterprise/mod.rs index a76ac9a2..36d30219 100644 --- a/src-tauri/src/enterprise/mod.rs +++ b/src-tauri/src/enterprise/mod.rs @@ -1,6 +1,7 @@ -pub mod inspector; +pub use defguard_client_posture::inspector; +pub use defguard_client_posture::posture; +pub use defguard_client_provisioning::{try_get_provisioning_config, ProvisioningConfig}; pub mod models; pub mod periodic; -pub mod posture; pub mod provisioning; pub mod service_locations; diff --git a/src-tauri/src/enterprise/models/instance.rs b/src-tauri/src/enterprise/models/instance.rs index b0a03437..4651d9df 100644 --- a/src-tauri/src/enterprise/models/instance.rs +++ b/src-tauri/src/enterprise/models/instance.rs @@ -1,28 +1 @@ -use sqlx::SqliteExecutor; - -use crate::{ - database::models::{ - instance::{ClientTrafficPolicy, Instance}, - Id, - }, - error::Error, -}; - -impl Instance { - pub async fn disable_enterprise_features<'e, E>(&mut self, executor: E) -> Result<(), Error> - where - E: SqliteExecutor<'e>, - { - debug!( - "Disabling enterprise features for instance {}({})", - self.name, self.id - ); - self.client_traffic_policy = ClientTrafficPolicy::None; - self.save(executor).await?; - debug!( - "Disabled enterprise features for instance {}({})", - self.name, self.id - ); - Ok(()) - } -} +pub use defguard_client_config_sync::commands::disable_enterprise_features; diff --git a/src-tauri/src/enterprise/periodic/config.rs b/src-tauri/src/enterprise/periodic/config.rs index 22d68031..303c1773 100644 --- a/src-tauri/src/enterprise/periodic/config.rs +++ b/src-tauri/src/enterprise/periodic/config.rs @@ -1,43 +1,41 @@ use std::{ - cmp::Ordering, collections::HashSet, - str::FromStr, sync::{LazyLock, Mutex}, time::Duration, }; -use reqwest::StatusCode; -use serde::Serialize; -use sqlx::{Sqlite, Transaction}; -use tauri::{AppHandle, Emitter, Url}; -use tokio::time::sleep; - -use crate::{ - active_connections::active_connections, - commands::{do_update_instance, locations_changed}, +pub use defguard_client_config_sync::commands::{ + disable_enterprise_features, do_update_instance, locations_changed, +}; +use defguard_client_config_sync::{config_changed, fetch_instance_config}; +use defguard_client_core::{ + connection::active_connections::active_connections, database::{ models::{instance::Instance, Id}, DB_POOL, }, error::Error, events::EventKey, - proto::defguard::client_types::{ - DeviceConfigResponse, InstanceInfoRequest, InstanceInfoResponse, - }, - utils::post_with_headers, - MIN_CORE_VERSION, MIN_PROXY_VERSION, }; +use log::{debug, error, info}; +use sqlx::{Sqlite, Transaction}; +use tauri::{AppHandle, Emitter}; +use tokio::time::sleep; const INTERVAL_SECONDS: Duration = Duration::from_secs(30); -static POLLING_ENDPOINT: &str = "/api/v1/poll"; + +/// Tracks instance IDs for which we already sent a version-mismatch notification, +/// to prevent duplicate notifications in the app's lifetime. +static NOTIFIED_INSTANCES: LazyLock>> = + LazyLock::new(|| Mutex::new(HashSet::new())); /// Periodically retrieves and updates configuration for all [`Instance`]s. /// Updates are only performed if no connections are established to the [`Instance`], -/// otherwise event is emmited and UI message is displayed. +/// otherwise event is emitted and UI message is displayed. pub async fn poll_config(handle: AppHandle) { debug!("Starting the configuration polling loop."); - // Polling starts sooner than app's frontend may load in dev builds, causing events (toasts) to be lost, - // you may want to wait here before starting if you want to debug it. + // Polling starts sooner than app's frontend may load in dev builds, causing events (toasts) + // to be lost; you may want to wait here before starting if you want to debug it. loop { let Ok(mut transaction) = DB_POOL.begin().await else { error!( @@ -49,7 +47,7 @@ pub async fn poll_config(handle: AppHandle) { }; let Ok(mut instances) = Instance::all_with_token(&mut *transaction).await else { error!( - "Failed to retireve instances for config polling, retrying in {}s", + "Failed to retrieve instances for config polling, retrying in {}s", INTERVAL_SECONDS.as_secs() ); let _ = transaction.rollback().await; @@ -68,8 +66,8 @@ pub async fn poll_config(handle: AppHandle) { match err { Error::CoreNotEnterprise => { debug!( - "Tried to contact core for instance {instance} config but it's not \ - enterprise, can't retrieve config" + "Tried to contact core for instance {instance} config but it's \ + not enterprise, can't retrieve config" ); } Error::NoToken => { @@ -118,81 +116,32 @@ pub async fn poll_config(handle: AppHandle) { } } -/// Retrieves configuration for given [`Instance`]. -/// Updates the instance if there aren't any active connections, otherwise displays UI message. +/// Retrieves configuration for a given [`Instance`]. +/// Updates the instance if there aren't any active connections, otherwise emits +/// a ConfigChanged event so the frontend can prompt the user to reconnect. pub async fn poll_instance( transaction: &mut Transaction<'_, Sqlite>, instance: &mut Instance, handle: &AppHandle, ) -> Result<(), Error> { - debug!("Getting config from core for instance {}", instance.name); - // Query proxy api - let request = build_request(instance)?; - let url = Url::from_str(&instance.proxy_url) - .and_then(|url| url.join(POLLING_ENDPOINT)) - .map_err(|_| { - Error::InternalError(format!( - "Can't build polling url: {}/{POLLING_ENDPOINT}", - instance.proxy_url - )) - })?; - let response = post_with_headers(url, &request).await; - let response = response.map_err(|err| { - Error::InternalError(format!( - "HTTP request failed for instance {}({}), url: {}, {err}", - instance.name, instance.id, instance.proxy_url - )) - })?; - debug!( - "Got the following config response for instance {} from core: {response:?}", - instance.name - ); - - check_min_version(&response, instance, handle); - - // Return early if the enterprise features are disabled in the core - if response.status() == StatusCode::PAYMENT_REQUIRED { - debug!( - "Instance {}({}) has enterprise features disabled, checking if this state is reflected \ - on our end.", - instance.name, instance.id - ); - if instance.enterprise_enabled { - info!( - "Instance {}({}) has enterprise features disabled, but we have them enabled, \ - disabling.", - instance.name, instance.id - ); - instance - .disable_enterprise_features(transaction.as_mut()) - .await?; - } else { - debug!( - "Instance {}({}) has enterprise features disabled, and we have them disabled as \ - well, no action needed", - instance.name, instance.id - ); + let fetched = fetch_instance_config(transaction, instance).await?; + + // Emit version-mismatch event if applicable and not already notified + if let Some(payload) = fetched.version_mismatch { + let mut notified_instances = NOTIFIED_INSTANCES.lock().unwrap(); + if notified_instances.insert(instance.id) { + if let Err(err) = handle.emit(EventKey::VersionMismatch.into(), payload) { + error!("Failed to emit version mismatch event to the frontend: {err}"); + // Remove so we can retry next cycle + notified_instances.remove(&instance.id); + } } - return Err(Error::CoreNotEnterprise); } - // Parse the response - debug!( - "Parsing the config response for instance {}.", - instance.name - ); - let response: InstanceInfoResponse = response.json().await.map_err(|err| { - Error::InternalError(format!( - "Failed to parse InstanceInfoResponse for instance {}({}): {err}", - instance.name, instance.id, - )) - })?; - let device_config = response - .device_config - .as_ref() - .ok_or_else(|| Error::InternalError("Device config not present in response".to_string()))?; - debug!("Parsed the config for instance {}", instance.name); - trace!("Parsed config: {device_config:?}"); + let device_config = + fetched.response.device_config.as_ref().ok_or_else(|| { + Error::InternalError("Device config not present in response".to_string()) + })?; // Early return if config didn't change if !config_changed(transaction, instance, device_config).await? { @@ -210,7 +159,6 @@ pub async fn poll_instance( // Config changed. If there are no active connections for this instance, update the database. // Otherwise just display a message to reconnect. - // if active_connections(instance).await?.is_empty() { debug!( "Updating instance {}({}) configuration: {device_config:?}", @@ -235,193 +183,3 @@ pub async fn poll_instance( Ok(()) } - -async fn config_changed( - transaction: &mut Transaction<'_, Sqlite>, - instance: &Instance, - device_config: &DeviceConfigResponse, -) -> Result { - debug!( - "Checking if config and any of the locations changed for instance {}({})", - instance.name, instance.id - ); - let locations_changed = locations_changed(transaction, instance, device_config).await?; - let info_changed = match &device_config.instance { - Some(info) => instance != info, - None => false, - }; - debug!( - "Did the locations change?: {locations_changed}. Did the instance information change?: \ - {info_changed}" - ); - Ok(locations_changed || info_changed) -} - -/// Retrieves token to build InstanceInfoRequest -fn build_request(instance: &Instance) -> Result { - let token = instance.token.as_ref().ok_or_else(|| Error::NoToken)?; - - Ok(InstanceInfoRequest { - token: (*token).clone(), - }) -} - -/// Tracks instance IDs that for which we already sent notification about version mismatches -/// to prevent duplicate notifications in the app's lifetime. -static NOTIFIED_INSTANCES: LazyLock>> = - LazyLock::new(|| Mutex::new(HashSet::new())); - -const CORE_VERSION_HEADER: &str = "defguard-core-version"; -const CORE_CONNECTED_HEADER: &str = "defguard-core-connected"; -const PROXY_VERSION_HEADER: &str = "defguard-component-version"; - -#[derive(Clone, Serialize)] -struct VersionMismatchPayload { - instance_name: String, - instance_id: Id, - core_version: String, - proxy_version: String, - core_required_version: String, - proxy_required_version: String, - core_compatible: bool, - proxy_compatible: bool, -} - -fn check_min_version(response: &reqwest::Response, instance: &Instance, handle: &AppHandle) { - let mut notified_instances = NOTIFIED_INSTANCES.lock().unwrap(); - if notified_instances.contains(&instance.id) { - debug!( - "Instance {}({}) already notified about version mismatch, skipping", - instance.name, instance.id - ); - return; - } - - let detected_core_version: String; - let detected_proxy_version: String; - let defguard_core_connected: Option = response - .headers() - .get(CORE_CONNECTED_HEADER) - .and_then(|v| { - debug!( - "Defguard core connection status header for instance {}({}): {v:?}", - instance.name, instance.id - ); - v.to_str().ok() - }) - .and_then(|s| s.parse().ok()); - - let core_compatible = if let Some(core_version) = response.headers().get(CORE_VERSION_HEADER) { - if let Ok(core_version) = core_version.to_str() { - if let Ok(core_version) = semver::Version::from_str(core_version) { - detected_core_version = core_version.to_string(); - core_version.cmp_precedence(&MIN_CORE_VERSION) != Ordering::Less - } else { - warn!( - "Core version header: invalid semver string in response for instance {}({}): \ - '{core_version}'", - instance.name, instance.id - ); - detected_core_version = core_version.to_string(); - false - } - } else { - warn!( - "Core version header: invalid string in response for instance {}({}): \ - '{core_version:?}'", - instance.name, instance.id - ); - detected_core_version = "unknown".to_string(); - false - } - } else { - warn!( - "Core version header not present in response for instance {}({})", - instance.name, instance.id - ); - detected_core_version = "unknown".to_string(); - false - }; - - let proxy_compatible = if let Some(proxy_version) = response.headers().get(PROXY_VERSION_HEADER) - { - if let Ok(proxy_version) = proxy_version.to_str() { - if let Ok(proxy_version) = semver::Version::from_str(proxy_version) { - detected_proxy_version = proxy_version.to_string(); - proxy_version.cmp_precedence(&MIN_PROXY_VERSION) != Ordering::Less - } else { - warn!( - "Proxy version header not a valid semver string in response for instance {}({}): \ - '{proxy_version}'", - instance.name, instance.id - ); - detected_proxy_version = proxy_version.to_string(); - false - } - } else { - warn!( - "Proxy version header not a valid string in response for instance {}({}): \ - '{proxy_version:?}'", - instance.name, instance.id - ); - detected_proxy_version = "unknown".to_string(); - false - } - } else { - warn!( - "Proxy version header not present in response for instance {}({})", - instance.name, instance.id - ); - detected_proxy_version = "unknown".to_string(); - false - }; - - let should_inform = match defguard_core_connected { - Some(true) => { - debug!( - "Defguard core is connected for instance {}({})", - instance.name, instance.id - ); - true - } - Some(false) => { - info!( - "Defguard core is not connected for instance {}({})", - instance.name, instance.id - ); - false - } - None => { - debug!( - "Defguard core connection status unknown for instance {}({})", - instance.name, instance.id - ); - true - } - }; - - if should_inform && (!core_compatible || !proxy_compatible) { - warn!( - "Instance {} is running incompatible versions: core {detected_core_version}, proxy \ - {detected_proxy_version}. Required versions: core >= {MIN_CORE_VERSION}, proxy >= \ - {MIN_PROXY_VERSION}", - instance.name, - ); - - let payload = VersionMismatchPayload { - instance_name: instance.name.clone(), - instance_id: instance.id, - core_version: detected_core_version, - proxy_version: detected_proxy_version, - core_required_version: MIN_CORE_VERSION.to_string(), - proxy_required_version: MIN_PROXY_VERSION.to_string(), - core_compatible, - proxy_compatible, - }; - if let Err(err) = handle.emit(EventKey::VersionMismatch.into(), payload) { - error!("Failed to emit version mismatch event to the frontend: {err}"); - } else { - notified_instances.insert(instance.id); - } - } -} diff --git a/src-tauri/src/enterprise/provisioning/mod.rs b/src-tauri/src/enterprise/provisioning/mod.rs index 21a35931..b7fe1595 100644 --- a/src-tauri/src/enterprise/provisioning/mod.rs +++ b/src-tauri/src/enterprise/provisioning/mod.rs @@ -1,85 +1,16 @@ -use std::{fmt, fs, path::Path}; +pub use defguard_client_provisioning::{try_get_provisioning_config, ProvisioningConfig}; -use serde::{Deserialize, Serialize}; use tauri::{AppHandle, Manager}; -use crate::database::{models::instance::Instance, DB_POOL}; - -const CONFIG_FILE_NAME: &str = "provisioning.json"; - -#[derive(Clone, Deserialize, Serialize)] -pub struct ProvisioningConfig { - pub enrollment_url: String, - pub enrollment_token: String, -} - -impl fmt::Debug for ProvisioningConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { - enrollment_url, - enrollment_token: _, - } = self; - - f.debug_struct("ProvisioningConfig") - .field("enrollment_url", enrollment_url) - .field("enrollment_token", &"***") - .finish() - } -} - -impl ProvisioningConfig { - /// Load configuration from a file at `path`. - fn load(path: &Path) -> Option { - // read content to string first to handle Windows encoding issues - let file_content = match fs::read_to_string(path) { - Ok(content) => content, - Err(err) => { - warn!( - "Failed to open provisioning configuration file at {}. Error details: \ - {err}", - path.display() - ); - return None; - } - }; - - // strip Windows BOM manually - let file_content = file_content.trim_start_matches('\u{FEFF}'); - - match serde_json::from_str::(file_content) { - Ok(config) => Some(config), - Err(err) => { - warn!( - "Failed to parse provisioning configuration file at {}. Error details: \ - {err}", - path.display() - ); - None - } - } - } -} - -#[must_use] -pub fn try_get_provisioning_config(app_data_dir: &Path) -> Option { - debug!( - "Trying to find provisioning config in {}", - app_data_dir.display() - ); - - let config_file_path = app_data_dir.join(CONFIG_FILE_NAME); - ProvisioningConfig::load(&config_file_path) -} +use defguard_client_core::database::{models::instance::Instance, DB_POOL}; /// Checks if the client has already been initialized /// and tries to load provisioning config from file if necessary pub async fn handle_client_initialization(app_handle: &AppHandle) -> Option { - // check if client has already been initialized - // we assume that if any instances exist the client has been initialized match Instance::all(&*DB_POOL).await { Ok(instances) => { if instances.is_empty() { - debug!( + log::debug!( "Client has not been initialized yet. Checking if provisioning config exists" ); let data_dir = app_handle @@ -88,14 +19,14 @@ pub async fn handle_client_initialization(app_handle: &AppHandle) -> Option { - info!( + log::info!( "Provisioning config found in {}: {config:?}", data_dir.display() ); return Some(config); } None => { - debug!( + log::debug!( "Provisioning config not found in {}. Proceeding with normal startup.", data_dir.display() ); @@ -104,7 +35,7 @@ pub async fn handle_client_initialization(app_handle: &AppHandle) -> Option { - error!("Failed to verify if the client has already been initialized: {err}"); + log::error!("Failed to verify if the client has already been initialized: {err}"); } } diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs index 199b0c67..23b60ff9 100644 --- a/src-tauri/src/enterprise/service_locations/mod.rs +++ b/src-tauri/src/enterprise/service_locations/mod.rs @@ -1,123 +1,7 @@ -use std::{collections::HashMap, fmt}; - -use defguard_wireguard_rs::{error::WireguardInterfaceError, WGApi}; -use serde::{Deserialize, Serialize}; - -use crate::{ - database::models::{ - location::{Location, ServiceLocationMode}, - Id, - }, - service::proto::defguard::client::v1::ServiceLocation, +pub use defguard_client_service_locations::{ + to_service_location, ServiceLocationData, ServiceLocationError, ServiceLocationManager, + SingleServiceLocationData, }; #[cfg(windows)] -pub mod windows; - -#[derive(Debug, thiserror::Error)] -pub enum ServiceLocationError { - #[error("Error occurred while initializing service location API: {0}")] - InitError(String), - #[error("Failed to load service location storage: {0}")] - LoadError(String), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - DecodeError(#[from] base64::DecodeError), - #[error(transparent)] - WireGuardError(#[from] WireguardInterfaceError), - #[error(transparent)] - AddrParseError(#[from] defguard_wireguard_rs::net::IpAddrParseError), - #[error("WireGuard interface error: {0}")] - InterfaceError(String), - #[error(transparent)] - JsonError(#[from] serde_json::Error), - #[error(transparent)] - ProtoEnumError(#[from] prost::UnknownEnumValue), - #[cfg(windows)] - #[error(transparent)] - WindowsServiceError(#[from] windows_service::Error), -} - -#[allow(dead_code)] -#[derive(Default)] -pub(crate) struct ServiceLocationManager { - // Interface name: WireGuard API instance - wgapis: HashMap, - // Instance ID: Service locations connected under that instance - connected_service_locations: HashMap>, -} - -#[allow(dead_code)] -#[derive(Serialize, Deserialize)] -pub(crate) struct ServiceLocationData { - pub service_locations: Vec, - pub instance_id: String, - pub private_key: String, -} - -#[allow(dead_code)] -pub(crate) struct SingleServiceLocationData { - pub service_location: ServiceLocation, - pub instance_id: String, - pub private_key: String, -} - -impl fmt::Debug for ServiceLocationData { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ServiceLocationData") - .field("service_locations", &self.service_locations) - .field("instance_id", &self.instance_id) - .field("private_key", &"***") - .finish() - } -} - -impl fmt::Debug for SingleServiceLocationData { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SingleServiceLocationData") - .field("service_locations", &self.service_location) - .field("instance_id", &self.instance_id) - .field("private_key", &"***") - .finish() - } -} - -impl Location { - pub fn to_service_location(&self) -> Result { - if !self.is_service_location() { - warn!("Location {self} is not a service location, so it can't be converted to one."); - return Err(crate::error::Error::ConversionError(format!( - "Failed to convert location {self} to a service location as it's either not marked \ - as one or has MFA enabled." - ))); - } - - let mode = match self.service_location_mode { - ServiceLocationMode::Disabled => { - warn!( - "Location {self} has an invalid service location mode, so it can't be converted to \ - one." - ); - return Err(crate::error::Error::ConversionError(format!( - "Location {self} has an invalid service location mode ({:?}), so it can't be \ - converted to one.", - self.service_location_mode - ))); - } - ServiceLocationMode::PreLogon => 0, - ServiceLocationMode::AlwaysOn => 1, - }; - - Ok(ServiceLocation { - name: self.name.clone(), - address: self.address.clone(), - pubkey: self.pubkey.clone(), - endpoint: self.endpoint.clone(), - allowed_ips: self.allowed_ips.clone(), - dns: self.dns.clone().unwrap_or_default(), - keepalive_interval: self.keepalive_interval.try_into().unwrap_or(0), - mode, - }) - } -} +pub use defguard_client_service_locations::windows; diff --git a/src-tauri/src/enterprise/service_locations/windows.rs b/src-tauri/src/enterprise/service_locations/windows.rs index cf8aee72..2cf9b8e7 100644 --- a/src-tauri/src/enterprise/service_locations/windows.rs +++ b/src-tauri/src/enterprise/service_locations/windows.rs @@ -9,7 +9,7 @@ use std::{ time::Duration, }; -use common::{dns_borrow, find_free_tcp_port, get_interface_name}; +use defguard_client_common::{dns_borrow, find_free_tcp_port, get_interface_name}; use defguard_wireguard_rs::{ key::Key, net::IpAddrMask, peer::Peer, InterfaceConfiguration, WireguardInterfaceApi, }; diff --git a/src-tauri/src/events.rs b/src-tauri/src/events.rs index eb5591e2..1d8b4125 100644 --- a/src-tauri/src/events.rs +++ b/src-tauri/src/events.rs @@ -7,41 +7,7 @@ use crate::{ ConnectionType, }; -// Match src/pages/client/types.ts. -#[non_exhaustive] -pub enum EventKey { - ConnectionChanged, - InstanceUpdate, - LocationUpdate, - AppVersionFetch, - ConfigChanged, - DeadConnectionDropped, - DeadConnectionReconnected, - ApplicationConfigChanged, - AddInstance, - MfaTrigger, - VersionMismatch, - UuidMismatch, -} - -impl From for &'static str { - fn from(key: EventKey) -> &'static str { - match key { - EventKey::ConnectionChanged => "connection-changed", - EventKey::InstanceUpdate => "instance-update", - EventKey::LocationUpdate => "location-update", - EventKey::AppVersionFetch => "app-version-fetch", - EventKey::ConfigChanged => "config-changed", - EventKey::DeadConnectionDropped => "dead-connection-dropped", - EventKey::DeadConnectionReconnected => "dead-connection-reconnected", - EventKey::ApplicationConfigChanged => "application-config-changed", - EventKey::AddInstance => "add-instance", - EventKey::MfaTrigger => "mfa-trigger", - EventKey::VersionMismatch => "version-mismatch", - EventKey::UuidMismatch => "uuid-mismatch", - } - } -} +pub use defguard_client_core::events::EventKey; /// Used as payload for [`DEAD_CONNECTION_DROPPED`] event #[derive(Clone, Serialize)] diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index abcf5937..abeb385c 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,29 +1,12 @@ // FIXME: actually refactor errors instead #![allow(clippy::result_large_err)] -#[cfg(unix)] -use std::path::Path; -use std::{fmt, path::PathBuf}; -#[cfg(not(windows))] -use std::{ - fs::{set_permissions, Permissions}, - os::unix::fs::PermissionsExt, -}; - -use chrono::NaiveDateTime; -use semver::Version; -use serde::{Deserialize, Serialize}; - -use self::database::models::{Id, NoId}; pub mod active_connections; -pub mod app_config; #[cfg(target_os = "macos")] pub mod apple; pub mod appstate; pub mod commands; -pub mod database; pub mod enterprise; -pub mod error; pub mod events; pub mod log_watcher; pub mod periodic; @@ -31,104 +14,44 @@ pub mod proto; pub mod service; pub mod tray; pub mod utils; -pub mod wg_config; pub mod window_manager; +// Re-export from core so existing imports keep working. +pub use defguard_client_core::version::{ + Version, CLIENT_PLATFORM_HEADER, CLIENT_VERSION_HEADER, LOG_FILENAME, MIN_CORE_VERSION, + MIN_PROXY_VERSION, +}; +pub use defguard_client_core::{ + app_config, + app_data_dir, + connection, + database, + error, + get_aggregation, + into_location, + proxy, + set_perms, + wg_config, + // Shared types + CommonConnection, + CommonConnectionInfo, + CommonLocationStats, + CommonWireguardFields, + ConnectionType, + // DateTime aggregation + DateTimeAggregation, + // Constants + DEFAULT_ROUTE_IPV4, + DEFAULT_ROUTE_IPV6, +}; + pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "-", env!("VERGEN_GIT_SHA")); -pub const MIN_CORE_VERSION: Version = Version::new(1, 6, 0); -pub const MIN_PROXY_VERSION: Version = Version::new(1, 6, 0); -pub const CLIENT_VERSION_HEADER: &str = "defguard-client-version"; -pub const CLIENT_PLATFORM_HEADER: &str = "defguard-client-platform"; pub const PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); -// Must be without ".log" suffix! -pub const LOG_FILENAME: &str = "defguard-client"; -// This must match tauri.bundle.identifier from tauri.conf.json. -const BUNDLE_IDENTIFIER: &str = "net.defguard"; -// Returns the path to the user's data directory. -#[must_use] -pub fn app_data_dir() -> Option { - dirs_next::data_dir().map(|dir| dir.join(BUNDLE_IDENTIFIER)) -} - -/// Ensures path has appropriate permissions set (dg25-28): -/// - 700 for directories -/// - 600 for files -#[cfg(unix)] -pub fn set_perms(path: &Path) { - let perms = if path.is_dir() { 0o700 } else { 0o600 }; - if let Err(err) = set_permissions(path, Permissions::from_mode(perms)) { - warn!( - "Failed to set permissions on path {}: {err}", - path.display() - ); - } -} - -/// Location type used in commands to check if we using tunnel or location -#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Serialize)] -pub enum ConnectionType { - Tunnel, - Location, -} - -impl fmt::Display for ConnectionType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ConnectionType::Tunnel => write!(f, "tunnel"), - ConnectionType::Location => write!(f, "location"), - } - } -} #[macro_use] extern crate log; -/// Common fields for Tunnel and Location -#[derive(Debug, Serialize, Deserialize)] -pub struct CommonWireguardFields { - pub instance_id: Id, - // Native network ID from Defguard Core. - pub network_id: Id, - pub name: String, - pub address: String, - pub pubkey: String, - pub endpoint: String, - pub allowed_ips: String, - pub dns: Option, - pub route_all_traffic: bool, -} - -/// Common fields for Connection and TunnelConnection due to shared command -#[derive(Debug, Serialize, Deserialize)] -pub struct CommonConnection { - pub id: I, - pub location_id: Id, - pub start: NaiveDateTime, - pub end: NaiveDateTime, - pub connection_type: ConnectionType, -} - -// Common fields for LocationStats and TunnelStats due to shared command -#[derive(Debug, Serialize, Deserialize)] -pub struct CommonLocationStats { - pub id: I, - pub location_id: Id, - pub upload: i64, - pub download: i64, - pub last_handshake: i64, - pub collected_at: NaiveDateTime, - pub listen_port: u32, - pub persistent_keepalive_interval: Option, - pub connection_type: ConnectionType, -} - -// Common fields for ConnectionInfo and TunnelConnectionInfo due to shared command -#[derive(Debug, Serialize)] -pub struct CommonConnectionInfo { - pub id: Id, - pub location_id: Id, - pub start: NaiveDateTime, - pub end: NaiveDateTime, - pub upload: Option, - pub download: Option, +/// Converts a tauri emit result into our error type. +pub fn tauri_err_to_app_err(e: tauri::Error) -> defguard_client_core::error::Error { + defguard_client_core::error::Error::Tauri(e.to_string()) } diff --git a/src-tauri/src/proto.rs b/src-tauri/src/proto.rs index b89eb1cf..d1a5bd49 100644 --- a/src-tauri/src/proto.rs +++ b/src-tauri/src/proto.rs @@ -1,59 +1 @@ -use crate::database::models::{ - location::{ - infer_mfa_method, Location, LocationMfaMode as MfaMode, - ServiceLocationMode as SLocationMode, - }, - Id, NoId, -}; - -pub(crate) mod defguard { - pub(crate) use crate::service::proto::defguard::client_types; - - pub(crate) mod proxy { - pub(crate) mod v1 { - tonic::include_proto!("defguard.proxy.v1"); - } - } -} - -impl defguard::client_types::DeviceConfig { - #[must_use] - pub(crate) fn into_location(self, instance_id: Id) -> Location { - let location_mfa_mode = match self.location_mfa_mode { - Some(_location_mfa_mode) => self.location_mfa_mode().into(), - None => { - // handle legacy core response - // DEPRECATED(1.5): superseeded by location_mfa_mode - #[allow(deprecated)] - if self.mfa_enabled { - MfaMode::Internal - } else { - MfaMode::Disabled - } - } - }; - - let service_location_mode = match self.service_location_mode { - Some(_service_location_mode) => self.service_location_mode().into(), - None => SLocationMode::Disabled, // Default to disabled if not set - }; - - Location { - id: NoId, - instance_id, - network_id: self.network_id, - name: self.network_name, - address: self.assigned_ip, // Transforming assigned_ip to address - pubkey: self.pubkey, - endpoint: self.endpoint, - allowed_ips: self.allowed_ips, - dns: self.dns, - route_all_traffic: false, - keepalive_interval: self.keepalive_interval.into(), - location_mfa_mode, - service_location_mode, - mfa_method: infer_mfa_method(location_mfa_mode, None), - posture_check_required: self.posture_check_required.unwrap_or_default(), - } - } -} +pub(crate) use defguard_client_proto::defguard; diff --git a/src-tauri/src/service/client.rs b/src-tauri/src/service/client.rs index 1897e2b1..7385b631 100644 --- a/src-tauri/src/service/client.rs +++ b/src-tauri/src/service/client.rs @@ -1,70 +1 @@ -use std::sync::LazyLock; - -use hyper_util::rt::TokioIo; -#[cfg(windows)] -use tokio::net::windows::named_pipe::ClientOptions; -#[cfg(unix)] -use tokio::net::UnixStream; -use tonic::transport::channel::{Channel, Endpoint}; -#[cfg(unix)] -use tonic::transport::Uri; -use tower::service_fn; -#[cfg(windows)] -use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY; - -use crate::service::proto::defguard::client::v1::desktop_daemon_service_client::DesktopDaemonServiceClient; - -#[cfg(unix)] -use super::daemon::DAEMON_SOCKET_PATH; -#[cfg(windows)] -use super::named_pipe::PIPE_NAME; - -pub(crate) static DAEMON_CLIENT: LazyLock> = - LazyLock::new(|| { - debug!("Setting up gRPC client"); - // URL is ignored since we provide our own connectors for unix socket and windows named pipes. - let endpoint = Endpoint::from_static("http://localhost"); - let channel; - #[cfg(unix)] - { - channel = endpoint.connect_with_connector_lazy(service_fn(|_: Uri| async { - // Connect to a Unix domain socket. - let stream = match UnixStream::connect(DAEMON_SOCKET_PATH).await { - Ok(stream) => stream, - Err(err) if err.kind() == std::io::ErrorKind::PermissionDenied => { - error!( - "Permission denied for UNIX domain socket; please refer to \ - https://docs.defguard.net/support-1/troubleshooting#\ - unix-socket-permission-errors-when-desktop-client-attempts-to-connect-\ - to-vpn-on-linux-machines" - ); - return Err(err); - } - Err(err) => { - error!("Problem connecting to UNIX domain socket: {err}"); - return Err(err); - } - }; - info!("Created unix gRPC client"); - Ok::<_, std::io::Error>(TokioIo::new(stream)) - })); - }; - #[cfg(windows)] - { - channel = endpoint.connect_with_connector_lazy(service_fn(|_| async { - let client = loop { - match ClientOptions::new().open(PIPE_NAME) { - Ok(client) => break client, - Err(err) if err.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), - Err(err) => { - error!("Problem connecting to named pipe: {err}"); - return Err(err); - } - } - }; - info!("Created windows gRPC client"); - Ok::<_, std::io::Error>(TokioIo::new(client)) - })); - } - DesktopDaemonServiceClient::new(channel) - }); +pub(crate) use defguard_client_core::connection::daemon_client::DAEMON_CLIENT; diff --git a/src-tauri/src/service/mod.rs b/src-tauri/src/service/mod.rs index 345d068d..b19dca85 100644 --- a/src-tauri/src/service/mod.rs +++ b/src-tauri/src/service/mod.rs @@ -2,154 +2,4 @@ pub mod client; pub mod config; pub mod proto; - -#[cfg(not(target_os = "macos"))] -pub mod daemon; -#[cfg(windows)] -pub mod named_pipe; pub mod utils; -#[cfg(windows)] -pub mod windows; - -use std::{ - str::FromStr, - time::{Duration, UNIX_EPOCH}, -}; - -use defguard_wireguard_rs::{ - host::Host, key::Key, net::IpAddrMask, peer::Peer, InterfaceConfiguration, -}; - -use crate::service::proto::defguard::client::v1::{ - InterfaceConfig, InterfaceData, Peer as ProtoPeer, -}; - -impl From for InterfaceConfig { - fn from(config: InterfaceConfiguration) -> Self { - Self { - name: config.name, - prvkey: config.prvkey, - address: config - .addresses - .iter() - .map(ToString::to_string) - .collect::>() - .join(","), - port: u32::from(config.port), - peers: config.peers.into_iter().map(Into::into).collect(), - mtu: config.mtu, - } - } -} - -impl From for InterfaceConfiguration { - fn from(config: InterfaceConfig) -> Self { - let addresses = config - .address - .split(',') - .filter_map(|ip| IpAddrMask::from_str(ip.trim()).ok()) - .collect(); - Self { - name: config.name, - prvkey: config.prvkey, - addresses, - port: config.port as u16, - peers: config.peers.into_iter().map(Into::into).collect(), - mtu: config.mtu, - fwmark: None, // TODO: add to config - } - } -} - -impl From for ProtoPeer { - fn from(peer: Peer) -> Self { - Self { - public_key: peer.public_key.to_lower_hex(), - preshared_key: peer.preshared_key.map(|key| key.to_lower_hex()), - protocol_version: peer.protocol_version, - endpoint: peer.endpoint.map(|addr| addr.to_string()), - last_handshake: peer.last_handshake.map(|time| { - time.duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs() - }), - tx_bytes: peer.tx_bytes, - rx_bytes: peer.rx_bytes, - persistent_keepalive_interval: peer.persistent_keepalive_interval.map(u32::from), - allowed_ips: peer - .allowed_ips - .into_iter() - .map(|addr| addr.to_string()) - .collect(), - } - } -} - -impl From for Peer { - fn from(peer: ProtoPeer) -> Self { - Self { - public_key: Key::decode(peer.public_key).expect("Failed to parse public key"), - preshared_key: peer - .preshared_key - .map(|key| Key::decode(key).expect("Failed to parse preshared key: {key}")), - protocol_version: peer.protocol_version, - endpoint: peer.endpoint.map(|addr| { - addr.parse() - .expect("Failed to parse endpoint address: {addr}") - }), - last_handshake: peer - .last_handshake - .map(|timestamp| UNIX_EPOCH + Duration::from_secs(timestamp)), - tx_bytes: peer.tx_bytes, - rx_bytes: peer.rx_bytes, - persistent_keepalive_interval: peer - .persistent_keepalive_interval - .and_then(|interval| u16::try_from(interval).ok()), - allowed_ips: peer - .allowed_ips - .into_iter() - .map(|addr| addr.parse().expect("Failed to parse allowed IP: {addr}")) - .collect(), - } - } -} - -impl From for InterfaceData { - fn from(host: Host) -> Self { - Self { - listen_port: u32::from(host.listen_port), - peers: host.peers.into_values().map(Into::into).collect(), - } - } -} - -#[cfg(test)] -mod tests { - use std::time::SystemTime; - - use x25519_dalek::{EphemeralSecret, PublicKey}; - - use super::*; - - #[test] - fn convert_peer() { - let secret = EphemeralSecret::random(); - let key = PublicKey::from(&secret); - let peer_key: Key = key.as_ref().try_into().unwrap(); - let mut base_peer = Peer::new(peer_key); - let addr = IpAddrMask::from_str("10.20.30.2/32").unwrap(); - base_peer.allowed_ips.push(addr); - // Workaround since nanoseconds are lost in conversion. - base_peer.last_handshake = Some(SystemTime::UNIX_EPOCH); - base_peer.protocol_version = Some(3); - base_peer.endpoint = Some("127.0.0.1:8080".parse().unwrap()); - base_peer.tx_bytes = 100; - base_peer.rx_bytes = 200; - - let proto_peer: ProtoPeer = base_peer.clone().into(); - - let converted_peer: Peer = proto_peer.into(); - - assert_eq!(base_peer, converted_peer); - } -} diff --git a/src-tauri/src/service/proto.rs b/src-tauri/src/service/proto.rs index ccc1b046..98d8833b 100644 --- a/src-tauri/src/service/proto.rs +++ b/src-tauri/src/service/proto.rs @@ -1,18 +1 @@ -pub mod defguard { - pub mod enterprise { - pub mod posture { - pub mod v2 { - tonic::include_proto!("defguard.enterprise.posture.v2"); - } - } - } - pub mod client { - pub mod v1 { - tonic::include_proto!("defguard.client.v1"); - } - } - - pub mod client_types { - tonic::include_proto!("defguard.client_types"); - } -} +pub use defguard_client_proto::defguard; diff --git a/src-tauri/src/tray.rs b/src-tauri/src/tray.rs index daa0bf0f..26e1aa0a 100644 --- a/src-tauri/src/tray.rs +++ b/src-tauri/src/tray.rs @@ -52,24 +52,30 @@ fn store_tray_click_position(app: &AppHandle, event: &TrayIconEvent) { /// Generate contents of system tray menu. async fn generate_tray_menu(app: &AppHandle) -> Result, Error> { debug!("Generating tray menu."); - let quit = MenuItem::with_id(app, TRAY_EVENT_QUIT, "Quit", true, None::<&str>)?; - let show = MenuItem::with_id(app, TRAY_EVENT_SHOW, "Show", true, None::<&str>)?; - let hide = MenuItem::with_id(app, TRAY_EVENT_HIDE, "Hide", true, None::<&str>)?; + let quit = MenuItem::with_id(app, TRAY_EVENT_QUIT, "Quit", true, None::<&str>) + .map_err(crate::tauri_err_to_app_err)?; + let show = MenuItem::with_id(app, TRAY_EVENT_SHOW, "Show", true, None::<&str>) + .map_err(crate::tauri_err_to_app_err)?; + let hide = MenuItem::with_id(app, TRAY_EVENT_HIDE, "Hide", true, None::<&str>) + .map_err(crate::tauri_err_to_app_err)?; let subscribe_updates = MenuItem::with_id( app, TRAY_EVENT_UPDATES, "Subscribe for updates", true, None::<&str>, - )?; + ) + .map_err(crate::tauri_err_to_app_err)?; let join_community = MenuItem::with_id( app, TRAY_EVENT_COMMUNITY, "Community support", true, None::<&str>, - )?; - let follow_us = MenuItem::with_id(app, TRAY_EVENT_FOLLOW, "Follow us", true, None::<&str>)?; + ) + .map_err(crate::tauri_err_to_app_err)?; + let follow_us = MenuItem::with_id(app, TRAY_EVENT_FOLLOW, "Follow us", true, None::<&str>) + .map_err(crate::tauri_err_to_app_err)?; let mut menu = MenuBuilder::new(app); debug!("Getting all instances information for the tray menu"); @@ -94,7 +100,8 @@ async fn generate_tray_menu(app: &AppHandle) -> Result, Error location.menu_label(), true, None::<&str>, - )?; + ) + .map_err(crate::tauri_err_to_app_err)?; menu = menu.item(&menu_item); } } else { @@ -114,10 +121,11 @@ async fn generate_tray_menu(app: &AppHandle) -> Result, Error location.menu_label(), true, None::<&str>, - )?; + ) + .map_err(crate::tauri_err_to_app_err)?; instance_menu = instance_menu.item(&menu_item); } - let submenu = instance_menu.build()?; + let submenu = instance_menu.build().map_err(crate::tauri_err_to_app_err)?; menu = menu.item(&submenu); } } @@ -127,14 +135,14 @@ async fn generate_tray_menu(app: &AppHandle) -> Result, Error } } - Ok(menu - .separator() + menu.separator() .items(&[&show, &hide]) .separator() .items(&[&subscribe_updates, &join_community, &follow_us]) .separator() .item(&quit) - .build()?) + .build() + .map_err(crate::tauri_err_to_app_err) } /// Setup system tray. @@ -197,8 +205,8 @@ pub async fn setup_tray(app: &AppHandle) -> Result<(), Error> { } }) .on_menu_event(handle_tray_menu_event) - .build(app)?; - + .build(app) + .map_err(crate::tauri_err_to_app_err)?; debug!("Tray menu successfully generated"); Ok(()) } @@ -296,8 +304,10 @@ pub async fn configure_tray_icon(app_handle: &AppHandle) -> Result<(), Error> { .path() .resolve(&resource_str, BaseDirectory::Resource) { - let icon = Image::from_path(icon_path)?; - tray_icon.set_icon(Some(icon))?; + let icon = Image::from_path(icon_path).map_err(crate::tauri_err_to_app_err)?; + tray_icon + .set_icon(Some(icon)) + .map_err(crate::tauri_err_to_app_err)?; debug!("Tray icon set to {resource_str} successfully."); Ok(()) } else { diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index 6f6bd3b4..f7c630fc 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -1,15 +1,11 @@ #[cfg(not(target_os = "macos"))] use std::str::FromStr; -use std::{env, path::Path, process::Command, time::Duration}; +use std::{env, path::Path, process::Command}; -use base64::{prelude::BASE64_STANDARD, Engine}; #[cfg(not(target_os = "macos"))] -use common::{find_free_tcp_port, get_interface_name}; +use defguard_client_common::{find_free_tcp_port, get_interface_name}; #[cfg(not(target_os = "macos"))] use defguard_wireguard_rs::{key::Key, net::IpAddrMask, peer::Peer, InterfaceConfiguration}; -use prost::Message; -use reqwest::{Client, Response}; -use serde::Serialize; use sqlx::query; use tauri::{AppHandle, Emitter, Manager}; #[cfg(not(target_os = "macos"))] @@ -43,8 +39,7 @@ use crate::{ error::Error, events::EventKey, log_watcher::service_log_watcher::spawn_log_watcher_task, - proto::defguard::client_types::ClientPlatformInfo, - ConnectionType, CLIENT_PLATFORM_HEADER, CLIENT_VERSION_HEADER, PKG_VERSION, + ConnectionType, }; #[cfg(not(target_os = "macos"))] use crate::{ @@ -72,59 +67,7 @@ pub(crate) async fn setup_interface( mtu: Option, pool: &DbPool, ) -> Result { - debug!("Setting up interface for location: {location}"); - let interface_name = get_interface_name(name); - - // request interface configuration - debug!("Looking for a free port for interface {interface_name}."); - let Some(port) = find_free_tcp_port() else { - let msg = format!( - "Couldn't find free port during interface {interface_name} setup for location \ - {location}" - ); - error!("{msg}"); - return Err(Error::InternalError(msg)); - }; - debug!("Found free port: {port} for interface {interface_name}."); - - let mut interface_config = location - .interface_configuration(pool, interface_name.clone(), preshared_key, mtu) - .await?; - interface_config.mtu = mtu; - debug!("Creating interface for location {location} with configuration {interface_config:?}"); - let request = CreateInterfaceRequest { - config: Some(interface_config.clone().into()), - dns: location.dns.clone(), - }; - if let Err(error) = DAEMON_CLIENT.clone().create_interface(request).await { - if error.code() == Code::Unavailable { - error!( - "Failed to set up connection for location {location}; background service is \ - unavailable. Make sure the service is running. Error: {error}, Interface \ - configuration: {interface_config:?}" - ); - Err(Error::InternalError( - "Background service is unavailable. Make sure the service is running.".into(), - )) - } else { - error!( - "Failed to send a request to the background service to create an interface for \ - location {location} with the following configuration: {interface_config:?}. \ - Error: {error}" - ); - Err(Error::InternalError(format!( - "Failed to send a request to the background service to create an interface for \ - location {location}. Error: {error}. Check logs for details." - ))) - } - } else { - info!( - "The interface for location {location} has been created successfully, interface \ - name: {}.", - interface_config.name - ); - Ok(interface_name) - } + crate::connection::setup::setup_interface(location, name, preshared_key, mtu, pool).await } #[cfg(target_os = "macos")] @@ -689,7 +632,9 @@ pub(crate) async fn handle_connection_for_location( .await; debug!("Sending event informing the frontend that a new connection has been created."); - handle.emit(EventKey::ConnectionChanged.into(), ())?; + handle + .emit(EventKey::ConnectionChanged.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; debug!("Event informing the frontend that a new connection has been created sent."); // spawn log watcher @@ -725,7 +670,9 @@ pub(crate) async fn handle_connection_for_tunnel( .await; debug!("Sending event informing the frontend that a new connection has been created."); - handle.emit(EventKey::ConnectionChanged.into(), ())?; + handle + .emit(EventKey::ConnectionChanged.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; debug!("Event informing the frontend that a new connection has been created sent."); // spawn log watcher @@ -1016,7 +963,9 @@ async fn check_connection( .await; debug!("Sending event informing the frontend that a new connection has been created."); - app_handle.emit(EventKey::ConnectionChanged.into(), ())?; + app_handle + .emit(EventKey::ConnectionChanged.into(), ()) + .map_err(crate::tauri_err_to_app_err)?; debug!("Event informing the frontend that a new connection has been created sent."); debug!("Spawning service log watcher for {connection_type} {name}..."); @@ -1089,27 +1038,6 @@ pub async fn sync_connections(app_handle: &AppHandle) -> Result<(), Error> { Ok(()) } -#[must_use] -pub(crate) fn construct_platform_header() -> String { - let os = os_info::get(); - - let platform_info = ClientPlatformInfo { - os_family: env::consts::FAMILY.to_string(), - os_type: env::consts::OS.to_string(), - version: os.version().to_string(), - edition: os.edition().map(str::to_string), - codename: os.codename().map(str::to_string), - bitness: Some(os.bitness().to_string()), - architecture: Some(env::consts::ARCH.to_string()), - }; - - debug!("Constructed platform info header: {platform_info:?}"); - - let buffer = platform_info.encode_to_vec(); - - BASE64_STANDARD.encode(buffer) -} - #[must_use] /// Utility function to get all tunnels and locations from the database. #[cfg(target_os = "macos")] @@ -1118,21 +1046,3 @@ pub async fn get_all_tunnels_locations() -> (Vec>, Vec>) let locations = Location::all(&*DB_POOL, false).await.unwrap_or_default(); (tunnels, locations) } - -const HTTP_REQ_TIMEOUT: Duration = Duration::from_secs(5); -pub(crate) async fn post_with_headers( - url: tauri::Url, - data: &T, -) -> Result -where - T: Serialize + ?Sized, -{ - Client::new() - .post(url) - .json(data) - .header(CLIENT_VERSION_HEADER, PKG_VERSION) - .header(CLIENT_PLATFORM_HEADER, construct_platform_header()) - .timeout(HTTP_REQ_TIMEOUT) - .send() - .await -} diff --git a/src-tauri/tauri.linux.conf.json b/src-tauri/tauri.linux.conf.json index 8e0c3208..c2753d76 100644 --- a/src-tauri/tauri.linux.conf.json +++ b/src-tauri/tauri.linux.conf.json @@ -1,10 +1,5 @@ { "productName": "defguard-client", - "build": { - "features": [ - "service" - ] - }, "bundle": { "longDescription": "IMPORTANT: Reboot or Re-login Required\nOn initial install the user is added to the defguard group.\nA reboot or logging out and back in is required for group membership changes to take effect.\nThis is not required on subsequent updates." } diff --git a/src-tauri/tauri.windows.conf.json b/src-tauri/tauri.windows.conf.json index eb86e8c7..826c27e7 100644 --- a/src-tauri/tauri.windows.conf.json +++ b/src-tauri/tauri.windows.conf.json @@ -1,9 +1,4 @@ { - "build": { - "features": [ - "service" - ] - }, "bundle": { "targets": [ "msi"