// File gmsa-kinit.cpp of Package gmsa-kinit

// SPDX-FileCopyrightText: 2022 Amazon.com, Inc. or its affiliates
// SPDX-FileCopyrightText: 2025 Enno Tensing <tenno+gmsa-kinit@suij.in>
// SPDX-License-Identifier: Apache-2.0
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <getopt.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <glib.h>
#include <openssl/crypto.h>

/*
 * Byte overlay for the decoded msDS-ManagedPassword blob.
 *
 * Only the fixed header and the start of the current password are modelled.
 * The field order and widths determine the byte offsets used to read the
 * password (the header is 2+2+4+4*2 = 16 bytes, so current_password starts
 * at offset 16); do not reorder or resize them.
 */
struct blob_t_ {
	uint16_t version;
	uint16_t reserved;
	uint32_t length;
	uint16_t current_password_offset;
	uint16_t previous_password_offset;
	uint16_t query_password_interval_offset;
	uint16_t unchanged_password_interval_offset;
	/* TBD: Fix this, remaining buf size is variable */
#define BLOB_REMAINING_BUF_SIZE 1024
	/* TBD: Get from parsed blob; number of password bytes handed to kinit */
#define GMSA_PASSWORD_SIZE 256
	uint8_t current_password[BLOB_REMAINING_BUF_SIZE];
	/* TBD:: Add remaining fields here */
};
using blob_t = blob_t_;

/*
 * Settings for one gMSA ticket request, filled in from the command line.
 * credspec_info and credential_arn are carried along but not read in this
 * file.
 */
struct krb_ticket_info_t {
	std::string krb_file_path;        /* credential cache passed to kinit -c */
	std::string service_account_name; /* gMSA account name (without '$') */
	std::string domain_name;          /* LDAP/Kerberos domain */
	std::string credspec_info;
	std::string distinguished_name;   /* gMSA DN; looked up when empty */
	std::string credential_arn;
	std::string base_dn;              /* search base for the DN lookup */
	std::string ldap_host;            /* optional specific LDAP server */
};

/*
 * Program that the raw password bytes are piped through (executed via
 * `mono`) before reaching kinit; see get_gmsa_krb_ticket().
 * NOTE(review): by its name it converts UTF-16 to UTF-8 — confirm.
 */
#define DECODE_PROG "/usr/sbin/utf16toutf8"

/*
 * GET_ARG(a, b): if option argument `a` is non-null, assign it to `b`;
 * otherwise print the usage text. Processing is NOT aborted in the null
 * case. The expansion calls print_help(), so it may only be used after
 * print_help() has been declared.
 */
#define GET_ARG(a, b)	\
do {				\
	if (a) {		\
		b = a;		\
	} else {		\
		print_help();		\
	}			\
} while(0)

/*
 * Split `input_string` on `delimiter`.
 *
 * Special case: when the delimiter is '=', only the first '=' splits; the
 * remainder of the input is then returned line by line (split on '\n'), so
 * values that themselves contain '=' stay intact.
 */
static std::vector<std::string> split_string(std::string input_string,
					     char delimiter)
{
	std::istringstream in(input_string);
	std::vector<std::string> pieces;
	std::string piece;

	if (delimiter == '=') {
		if (std::getline(in, piece, '=')) {
			pieces.push_back(piece);
			while (std::getline(in, piece))
				pieces.push_back(piece);
		}
		return pieces;
	}

	while (std::getline(in, piece, delimiter))
		pieces.push_back(piece);
	return pieces;
}

static uint8_t *base64_decode(const std::string &password,
			      gsize *base64_decode_len)
{
	if (base64_decode_len == nullptr || password.empty()) {
		return nullptr;
	}

	*base64_decode_len = 0;
	guchar *result = g_base64_decode(password.c_str(), base64_decode_len);
	if (result == nullptr || *base64_decode_len <= 0) {
		return nullptr;
	}

	void *secure_mem = OPENSSL_malloc(*base64_decode_len);
	if (secure_mem == nullptr) {
		g_free(result);
		return nullptr;
	}

	memcpy(secure_mem, result, *base64_decode_len);

	memset(result, 0, *base64_decode_len);
	g_free(result);

	/**
         * secure_mem must be freed later
         */
	return (uint8_t *)secure_mem;
}

static std::pair<size_t, void *> find_password(std::string ldap_search_result)
{
	size_t base64_decode_len = 0;
	std::vector<std::string> results;

	std::string password = std::string("msDS-ManagedPassword::");
	results = split_string(ldap_search_result, '#');
	bool password_found = false;
	for (auto &result : results) {
		auto found = result.find(password);
		if (found != std::string::npos) {
			found += password.length();
			password = result.substr(found + 1, result.length());
			// std::cerr << "Password = " << password << std::endl;
			password_found = true;
			break;
		}
	}

	uint8_t *blob_base64_decoded = nullptr;
	if (password_found) {
		blob_base64_decoded =
			base64_decode(password, &base64_decode_len);
		if (blob_base64_decoded == nullptr) {
			std::cerr << "ERROR: base64 buffer is null"
				  << std::endl;
			return std::make_pair(0, nullptr);
		}
	}

	return std::make_pair(base64_decode_len, blob_base64_decoded);
}

/**
 * Run a command through the shell (popen) and capture its standard output.
 *
 * @param cmd Command line; it is interpreted by /bin/sh, so callers are
 *            responsible for quoting.
 * @return {status, output}: status is pclose()'s wait status (0 on a clean
 *         exit), or -1 with empty output if the pipe could not be opened.
 */
static std::pair<int, std::string> exec_shell_cmd(std::string cmd)
{
	/*
	 * popen() takes a const char *, so cmd.c_str() is passed directly;
	 * the former calloc()+strncpy() copy was unnecessary and unchecked.
	 */
	FILE *stream = popen(cmd.c_str(), "r");
	if (stream == nullptr) {
		return std::make_pair(-1, std::string(""));
	}

	/* fread() keeps the output byte-exact, including embedded NULs. */
	std::string output;
	char buf[4096];
	size_t nread;
	while ((nread = fread(buf, 1, sizeof(buf), stream)) > 0) {
		output.append(buf, nread);
	}

	const int error_code = pclose(stream);
	return std::make_pair(error_code, output);
}

/**
 * Run ldapsearch (SASL/GSSAPI) against `ldap_host`, or against `domain`
 * when no host is given. One retry is made if the first attempt fails.
 *
 * @param gmsa_account_name Unused; kept for interface compatibility.
 * @param base_dn           Search base, passed single-quoted to -b.
 * @param domain            Used as the server when ldap_host is empty.
 * @param ldap_host         Optional specific server; also enables -N.
 * @param search_string     Extra arguments (scope, filter, attributes).
 * @return {status, text}: status is the exec_shell_cmd() status of the last
 *         attempt (0 = success). text is the command line followed by
 *         ldapsearch's output, followed by a status message. Callers rely on
 *         this prefix for logging.
 *
 * NOTE(review): base_dn and search_string are interpolated into a shell
 * command without escaping; a single quote in them breaks the command.
 */
static std::pair<int, std::string>
execute_ldapsearch(std::string gmsa_account_name, std::string base_dn,
		   std::string domain, std::string ldap_host, std::string search_string)
{
	(void)gmsa_account_name;

	const std::string &fqdn = ldap_host.empty() ? domain : ldap_host;

	std::string cmd =
		std::string("ldapsearch -o ldif_wrap=no -LLL -Y GSSAPI -H ldap://") +
		fqdn;
	cmd += std::string(" -b '") + base_dn + std::string("' ") + search_string;
	if (!ldap_host.empty()) {
		/*
		 * -N: Do not use reverse DNS to canonicalize SASL host name.
		 * With this flag, ldapsearch uses the IP address directly for
		 * identification purposes, rather than trying to resolve it to a hostname.
		 *
		 * However if ldap_host is not set, we query against an entire domain and not just
		 * a single DC, so we have to resolve to hostname.
		 */
		cmd += std::string(" -N");
	}

	std::cerr << "INFO: " << cmd << std::endl;

	/* ldapsearch has been observed to fail once and then succeed on retry. */
	const int max_attempts = 2;
	std::pair<int, std::string> ldap_search_result;
	for (int attempt = 0; attempt < max_attempts; attempt++) {
		/*
		 * Always run the unmodified command. The previous attempt's
		 * output must never be appended to cmd: that broke the retry and
		 * handed server-controlled text to the shell.
		 */
		ldap_search_result = exec_shell_cmd(cmd);
		ldap_search_result.second = cmd + ldap_search_result.second;

		if (ldap_search_result.first == 0) {
			const std::string info_msg =
				"INFO: ldapsearch succeeded with FQDN = ";
			std::cerr << info_msg << fqdn << std::endl;
			ldap_search_result.second += info_msg;
			break;
		}

		std::cerr << "ERROR: ldapsearch failed with FQDN = " << fqdn
			  << std::endl;
		const std::string err_msg =
			"ERROR: ldapsearch failed to get gMSA credentials: " +
			ldap_search_result.second;
		std::cerr << err_msg << std::endl;
		ldap_search_result.second += err_msg;
	}

	return ldap_search_result;
}

/**
 * Look up the distinguishedName of a gMSA by its CN.
 *
 * Runs, in effect:
 *   ldapsearch -H ldap://<host> -b '<base_dn>' -s sub '(CN=<name>)' distinguishedName
 * and takes the text after the first "distinguishedName: " up to the next
 * newline.
 *
 * @return {0, dn} on success; {0, ""} if the search succeeded but no
 *         "distinguishedName:" label was found; {-1, ""} if ldapsearch
 *         failed, returned nothing, or the value had no terminating newline.
 */
static std::pair<int, std::string>
find_dn(std::string gmsa_account_name, std::string base_dn, std::string domain, std::string ldap_host)
{
	const std::string query =
		" -s sub '(CN=" + gmsa_account_name + ")' distinguishedName";
	const std::pair<int, std::string> reply = execute_ldapsearch(
		gmsa_account_name, base_dn, domain, ldap_host, query);

	if (reply.first != 0 || reply.second.empty()) {
		return std::make_pair(-1, std::string(""));
	}

	const std::string &text = reply.second;
	const std::size_t label_pos = text.find("distinguishedName:");
	if (label_pos == std::string::npos) {
		return std::make_pair(0, std::string(""));
	}

	/* e.g. "distinguishedName: CN=WebApp01,OU=MYOU,...,DC=com" */
	const std::string prefix = "distinguishedName: ";
	const std::string rest = text.substr(label_pos + prefix.length());
	const std::size_t eol = rest.find_first_of("\n");
	if (eol == std::string::npos) {
		return std::make_pair(-1, std::string(""));
	}
	return std::make_pair(0, rest.substr(0, eol));
}

/**
 * Fetch a gMSA's managed password over LDAP and use it to obtain a Kerberos
 * ticket with kinit.
 *
 * Steps:
 *  1. Resolve the gMSA's distinguished name via find_dn(), unless one is
 *     already set.
 *  2. Read its msDS-ManagedPassword blob with ldapsearch.
 *  3. Pipe the current password through DECODE_PROG (run under mono) into
 *     `kinit -c <krb_file_path> -V '<account>$'@<DOMAIN>`.
 *
 * @param krb_ticket Request settings. Its distinguished_name is overwritten
 *                   with the DN that was used.
 * @return {kinit exit code, ""} once kinit has run (0 means success).
 *         {-1, message} for other failures; the message may be empty.
 */
std::pair<int, std::string> get_gmsa_krb_ticket(krb_ticket_info_t *krb_ticket)
{
	std::string gmsa_account_name;
	std::string distinguished_name;
	std::string domain_name;
	std::string krb_cc_name;
	std::string base_dn;
	std::string ldap_host;

	if (krb_ticket != nullptr) {
		gmsa_account_name = krb_ticket->service_account_name;
		distinguished_name = krb_ticket->distinguished_name;
		domain_name = krb_ticket->domain_name;
		krb_cc_name = krb_ticket->krb_file_path;
		base_dn = krb_ticket->base_dn;
		ldap_host = krb_ticket->ldap_host;
	}

	/* A null krb_ticket leaves domain_name empty and is rejected here. */
	if (domain_name.empty()) {
		return std::make_pair(-1, std::string("ERROR: No LDAP Domain set!"));
	}

	if (gmsa_account_name.empty()) {
		return std::make_pair(-1, std::string("ERROR: No gMSA Name set!"));
	}

	if (distinguished_name.empty()) {
		std::pair<int, std::string> dn_result =
			find_dn(gmsa_account_name, base_dn, domain_name, ldap_host);
		if (dn_result.first == 0 && !dn_result.second.empty()) {
			distinguished_name = dn_result.second;
		}
	}
	if (distinguished_name.empty()) {
		/* Searching from an empty base would not target this gMSA. */
		return std::make_pair(
			-1, "ERROR: Could not find the distinguished name of gMSA " +
				    gmsa_account_name);
	}
	std::cerr << "INFO: Found dn = " << distinguished_name << std::endl;

	krb_ticket->distinguished_name = distinguished_name;

	// Then find the password
	const std::string search_string = std::string(
		" -s sub  '(objectClass=msDs-GroupManagedServiceAccount)' msDS-ManagedPassword");
	std::pair<int, std::string> ldap_search_result = execute_ldapsearch(
		gmsa_account_name, distinguished_name, domain_name, ldap_host,
		search_string);
	if (ldap_search_result.first != 0) {
		std::cerr << "ldapsearch failed with FQDN = " << domain_name << " "
			  << ldap_search_result.second << " " << search_string
			  << std::endl;
		return std::make_pair(-1, std::string(""));
	}

	const std::size_t pos =
		ldap_search_result.second.find("msDS-ManagedPassword:");
	if (pos != std::string::npos) {
		/* Everything before the attribute is the command line. */
		std::cerr << "ldapsearch successful with FQDN = " << domain_name
			  << ", cmd = " << ldap_search_result.second.substr(0, pos)
			  << "," << "search_string = " << search_string
			  << std::endl;
	}

	std::pair<size_t, void *> password_found_result =
		find_password(ldap_search_result.second);
	/* The raw output holds the base64 password; scrub all of it. */
	OPENSSL_cleanse(&ldap_search_result.second[0],
			ldap_search_result.second.size());

	/* Wipe and free the decoded blob exactly once, on every path. */
	auto release_password = [&password_found_result]() {
		if (password_found_result.second != nullptr) {
			OPENSSL_cleanse(password_found_result.second,
					password_found_result.first);
			OPENSSL_free(password_found_result.second);
			password_found_result.second = nullptr;
		}
	};

	if (password_found_result.first == 0 ||
	    password_found_result.second == nullptr) {
		release_password();
		std::string log_str = "ERROR: Password not found";
		std::cerr << log_str << std::endl;
		return std::make_pair(-1, log_str);
	}

	/*
	 * The password is read at the fixed blob_t offset, so the decoded
	 * blob must be long enough to contain it. Otherwise the read would go
	 * past the end of the buffer.
	 */
	const size_t password_size = GMSA_PASSWORD_SIZE;
	const size_t password_offset = offsetof(blob_t, current_password);
	if (password_found_result.first < password_offset + password_size) {
		release_password();
		std::string log_str = "ERROR: gMSA password blob is too short";
		std::cerr << log_str << std::endl;
		return std::make_pair(-1, log_str);
	}
	const uint8_t *blob_password =
		static_cast<const uint8_t *>(password_found_result.second) +
		password_offset;

	std::transform(domain_name.begin(), domain_name.end(),
		       domain_name.begin(),
		       [](unsigned char c) { return std::toupper(c); });
	const std::string default_principal =
		"'" + gmsa_account_name + "$'" + "@" + domain_name;

	/* Pipe password to the utf16 decoder and kinit */
	const std::string kinit_cmd = std::string("mono ") +
				      std::string(DECODE_PROG) +
				      std::string(" | kinit ") + std::string(" -c ") +
				      krb_cc_name + " -V " + default_principal;
	std::cerr << "INFO:" << kinit_cmd << std::endl;
	FILE *fp = popen(kinit_cmd.c_str(), "w");
	if (fp == nullptr) {
		perror("kinit failed");
		release_password();
		std::cerr << "ERROR: kinit failed" << std::endl;
		return std::make_pair(-1, std::string("kinit failed"));
	}
	const size_t written = fwrite(blob_password, 1, password_size, fp);
	release_password();
	const int status = pclose(fp);

	if (written != password_size) {
		std::cerr << "ERROR: only " << written << " of " << password_size
			  << " password bytes were written to kinit" << std::endl;
	}

	/*
	 * pclose() returns a wait status, not an exit code. Passing it through
	 * unchanged let main() exit with e.g. 256, which the OS truncates to 0
	 * and therefore reports as success.
	 */
	int error_code = -1;
	if (status != -1 && WIFEXITED(status)) {
		error_code = WEXITSTATUS(status);
	}

	std::cerr << "INFO: kinit return value = " << error_code << std::endl;

	return std::make_pair(error_code, std::string(""));
}

/**
 * Print command-line usage to stderr.
 *
 * The optional gMSA distinguished name is given with -D (--gmsa-dn).
 * Lower-case -d is the LDAP domain; the old text wrongly showed -d for
 * both.
 */
void print_help(void)
{
	std::cerr << "gmsa-kinit -d LDAP_DOMAIN -g GMSA_ACCOUNT -b BASE_DN -k KERBEROS_FILE [-D GMSA_DN] [-H LDAP_DC]" << std::endl;
}

/**
 * Entry point: parse options into a krb_ticket_info_t and request a ticket.
 *
 * Required: -d/--domain, -g/--gmsa-account, -b/--base-dn, -k/--krb-file.
 * Optional: -D/--gmsa-dn, -H/--host. -h/--help prints usage and exits 0.
 *
 * @return 1 on usage errors; otherwise the first element of
 *         get_gmsa_krb_ticket()'s result.
 */
int main(int argc, char **argv)
{
	krb_ticket_info_t kt;

	static struct option long_options[] = {
		{ "domain", required_argument, 0, 'd' },
		{ "gmsa-account", required_argument, 0, 'g' },
		{ "base-dn", required_argument, 0, 'b' },
		{ "krb-file", required_argument, 0, 'k' },
		{ "gmsa-dn", required_argument, 0, 'D' },
		{ "host", required_argument, 0, 'H' },
		{ "help", no_argument, 0, 'h' },
		{ 0, 0, 0, 0 }
	};

	for (;;) {
		int option_index = 0;
		const int c = getopt_long(argc, argv, "d:g:b:k:D:H:h",
					  long_options, &option_index);

		if (c == -1)
			break;

		switch (c) {
		case 'd':
			GET_ARG(optarg, kt.domain_name);
			break;
		case 'g':
			GET_ARG(optarg, kt.service_account_name);
			break;
		case 'b':
			GET_ARG(optarg, kt.base_dn);
			break;
		case 'k':
			GET_ARG(optarg, kt.krb_file_path);
			break;
		case 'D':
			GET_ARG(optarg, kt.distinguished_name);
			break;
		case 'H':
			GET_ARG(optarg, kt.ldap_host);
			break;
		case 'h':
			print_help();
			return 0;
		default:
			/*
			 * Unknown option or missing argument. getopt_long() has
			 * already printed a diagnostic (opterr is left at its
			 * default), so show usage and fail instead of ignoring it.
			 */
			print_help();
			return 1;
		}
	}

	if (kt.domain_name.empty() || kt.service_account_name.empty() ||
	    kt.krb_file_path.empty() || kt.base_dn.empty()) {
		std::cerr << "ERROR: -d, -g, -b and -k are required" << std::endl;
		print_help();
		return 1;
	}

	std::pair<int, std::string> ret = get_gmsa_krb_ticket(&kt);

	if (!ret.second.empty())
		std::cerr << ret.second << std::endl;

	return ret.first;
}
// openSUSE Build Service is sponsored by