/* libp11, a simple layer on top of PKCS#11 API
 * Copyright (C) 2005 Olaf Kirch <okir@lst.de>
 * Copyright (C) 2016-2025 Michał Trojnara <Michal.Trojnara@stunnel.org>
 * Copyright © 2025 Mobi - Com Polska Sp. z o.o.
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Lesser General Public
 *  License as published by the Free Software Foundation; either
 *  version 2.1 of the License, or (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public
 *  License along with this library; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
 */

#include "libp11-int.h"
#include <string.h>
#include <openssl/ui.h>
#include <openssl/bn.h>

/* The maximum length of PIN; pkcs11_authenticate() sizes its PIN buffer as
 * MAX_PIN_LENGTH+1 and passes MAX_PIN_LENGTH as the maximum input size */
#define MAX_PIN_LENGTH   256

/* File-scope state that is only compiled for OpenSSL older than 4.0 */
#if OPENSSL_VERSION_NUMBER < 0x40000000L
# if OPENSSL_VERSION_NUMBER >= 0x30000000L
/* EVP_PKEY ex_data index used by pkcs11_set_ex_data_pkey() and
 * pkcs11_get_ex_data_pkey(); 0 means that no index has been allocated */
static int pkey_ex_index = 0;
# endif /* OPENSSL_VERSION_NUMBER >= 0x30000000L*/

/* Pointer to the original RSA EVP_PKEY_METHOD; it is const-qualified
 * except for OpenSSL older than 1.1.1 and for LibreSSL.
 * NOTE(review): the code that sets and uses it is outside this part of
 * the file */
#if OPENSSL_VERSION_NUMBER < 0x10101000L || defined(LIBRESSL_VERSION_NUMBER)
static EVP_PKEY_METHOD *orig_method_rsa = NULL;
#else
static const EVP_PKEY_METHOD *orig_method_rsa = NULL;
#endif /* OPENSSL_VERSION_NUMBER < 0x10101000L || defined(LIBRESSL_VERSION_NUMBER) */

/* Function pointers with the signatures of the sign_init, sign,
 * decrypt_init and decrypt EVP_PKEY_METHOD callbacks.
 * NOTE(review): presumably they hold the callbacks of the original RSA
 * method - confirm against the code that assigns them */
static int (*orig_pkey_rsa_sign_init) (EVP_PKEY_CTX *ctx);
static int (*orig_pkey_rsa_sign) (EVP_PKEY_CTX *ctx,
	unsigned char *sig, size_t *siglen,
	const unsigned char *tbs, size_t tbslen);
static int (*orig_pkey_rsa_decrypt_init) (EVP_PKEY_CTX *ctx);
static int (*orig_pkey_rsa_decrypt) (EVP_PKEY_CTX *ctx,
	unsigned char *out, size_t *outlen,
	const unsigned char *in, size_t inlen);
#endif /* OPENSSL_VERSION_NUMBER < 0x40000000L */

/* CKA_EC_PARAMS encodings of the supported Edwards curves.
 * pkcs11_object_from_handle() compares token-provided CKA_EC_PARAMS against
 * all four arrays; pkcs11_eddsa_keygen() sends only the OID forms. */
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
# ifndef OPENSSL_NO_EC
/* DER OIDs: tag 0x06, length 3, then 1.3.101.112 (Ed25519)
 * and 1.3.101.113 (Ed448) */
static const unsigned char OID_ED25519[] = { 0x06, 0x03, 0x2B, 0x65, 0x70 };
static const unsigned char OID_ED448[]   = { 0x06, 0x03, 0x2B, 0x65, 0x71 };

/* PrintableString forms used by some tokens (e.g. SoftHSM):
 * tag 0x13, one length byte, then the curve name without a trailing NUL */
static const unsigned char STR_ED25519[] = {
    0x13, 0x0C, /* tag + length */
    'e','d','w','a','r','d','s','2','5','5','1','9'
};
static const unsigned char STR_ED448[] = {
    0x13, 0x0A, /* tag + length */
    'e','d','w','a','r','d','s','4','4','8'
};
# endif /* OPENSSL_NO_EC */
#endif /* OPENSSL_VERSION_NUMBER >= 0x30000000L */

/* Local definition of struct evp_pkey_method_st (EVP_PKEY_METHOD), compiled
 * only for OpenSSL older than 0x100020d0 and for LibreSSL.
 * NOTE(review): presumably needed because those libraries do not provide
 * accessors for the method callbacks; the layout has to match the library's
 * own definition - confirm before changing any member */
#if OPENSSL_VERSION_NUMBER < 0x100020d0L || defined(LIBRESSL_VERSION_NUMBER)
struct evp_pkey_method_st {
    int pkey_id;
    int flags;
    int (*init) (EVP_PKEY_CTX *ctx);
    int (*copy) (EVP_PKEY_CTX *dst, EVP_PKEY_CTX *src);
    void (*cleanup) (EVP_PKEY_CTX *ctx);
    int (*paramgen_init) (EVP_PKEY_CTX *ctx);
    int (*paramgen) (EVP_PKEY_CTX *ctx, EVP_PKEY *pkey);
    int (*keygen_init) (EVP_PKEY_CTX *ctx);
    int (*keygen) (EVP_PKEY_CTX *ctx, EVP_PKEY *pkey);
    int (*sign_init) (EVP_PKEY_CTX *ctx);
    int (*sign) (EVP_PKEY_CTX *ctx, unsigned char *sig, size_t *siglen,
                 const unsigned char *tbs, size_t tbslen);
    int (*verify_init) (EVP_PKEY_CTX *ctx);
    int (*verify) (EVP_PKEY_CTX *ctx,
                   const unsigned char *sig, size_t siglen,
                   const unsigned char *tbs, size_t tbslen);
    int (*verify_recover_init) (EVP_PKEY_CTX *ctx);
    int (*verify_recover) (EVP_PKEY_CTX *ctx,
                           unsigned char *rout, size_t *routlen,
                           const unsigned char *sig, size_t siglen);
    int (*signctx_init) (EVP_PKEY_CTX *ctx, EVP_MD_CTX *mctx);
    int (*signctx) (EVP_PKEY_CTX *ctx, unsigned char *sig, size_t *siglen,
                    EVP_MD_CTX *mctx);
    int (*verifyctx_init) (EVP_PKEY_CTX *ctx, EVP_MD_CTX *mctx);
    int (*verifyctx) (EVP_PKEY_CTX *ctx, const unsigned char *sig, int siglen,
                      EVP_MD_CTX *mctx);
    int (*encrypt_init) (EVP_PKEY_CTX *ctx);
    int (*encrypt) (EVP_PKEY_CTX *ctx, unsigned char *out, size_t *outlen,
                    const unsigned char *in, size_t inlen);
    int (*decrypt_init) (EVP_PKEY_CTX *ctx);
    int (*decrypt) (EVP_PKEY_CTX *ctx, unsigned char *out, size_t *outlen,
                    const unsigned char *in, size_t inlen);
    int (*derive_init) (EVP_PKEY_CTX *ctx);
    int (*derive) (EVP_PKEY_CTX *ctx, unsigned char *key, size_t *keylen);
    int (*ctrl) (EVP_PKEY_CTX *ctx, int type, int p1, void *p2);
    int (*ctrl_str) (EVP_PKEY_CTX *ctx, const char *type, const char *value);
}; /* EVP_PKEY_METHOD */
#endif

/* Forward declarations of file-local helpers that are defined further down
 * in this file but used before their definitions */

/* Key enumeration and registration in the per-slot key lists */
static int pkcs11_find_keys(PKCS11_SLOT_private *, CK_SESSION_HANDLE, unsigned int,
	PKCS11_TEMPLATE *);
static int pkcs11_init_key(PKCS11_SLOT_private *, CK_SESSION_HANDLE session,
	CK_OBJECT_HANDLE o, CK_OBJECT_CLASS type, PKCS11_KEY **);
/* Acquires the R/W session used by the key generation functions */
static int pkcs11_init_keygen(PKCS11_SLOT_private *, CK_SESSION_HANDLE *);
static int pkcs11_next_key(PKCS11_CTX_private *ctx, PKCS11_SLOT_private *,
	CK_SESSION_HANDLE session, CK_OBJECT_CLASS type);
/* Common implementation of pkcs11_store_private_key() and
 * pkcs11_store_public_key() */
static int pkcs11_store_key(PKCS11_SLOT_private *, EVP_PKEY *, CK_OBJECT_CLASS,
	char *, unsigned char *, size_t, PKCS11_KEY **);
/* Attribute sets shared by the key pair generation templates */
static void pkcs11_common_pubkey_attr(PKCS11_TEMPLATE *, const char *,
	const unsigned char *, size_t);
static void pkcs11_common_privkey_attr(PKCS11_TEMPLATE *, const char *,
	const unsigned char *, size_t, const PKCS11_params *);

/*
 * Look up a single object matching the given template.
 *
 * At most one handle is requested from C_FindObjects.  The template
 * attributes are always released with pkcs11_zap_attrs() before returning.
 * Returns the object handle when the search succeeded and produced exactly
 * one handle, CK_INVALID_HANDLE otherwise.
 */
static CK_OBJECT_HANDLE pkcs11_handle_from_template(PKCS11_SLOT_private *slot,
	CK_SESSION_HANDLE session, PKCS11_TEMPLATE *tmpl)
{
	PKCS11_CTX_private *ctx = slot->ctx;
	CK_OBJECT_HANDLE found;
	CK_ULONG nfound;
	CK_RV status;

	status = CRYPTOKI_call(ctx,
		C_FindObjectsInit(session, tmpl->attrs, tmpl->nattr));
	if (status != CKR_OK) {
		/* The search was never started, so there is nothing to finalize */
		pkcs11_zap_attrs(tmpl);
		return CK_INVALID_HANDLE;
	}

	status = CRYPTOKI_call(ctx,
		C_FindObjects(session, &found, 1, &nfound));
	CRYPTOKI_call(ctx, C_FindObjectsFinal(session));
	pkcs11_zap_attrs(tmpl);

	if (status != CKR_OK || nfound != 1)
		return CK_INVALID_HANDLE;

	return found;
}

/*
 * Create a PKCS11_OBJECT_private describing the token object "object".
 *
 * Only public keys, private keys and certificates are handled.  Keys must
 * be of a supported CKA_KEY_TYPE (RSA, and depending on the build EC and
 * Ed25519/Ed448); certificates must be of type CKC_X_509.  Any other
 * object, or an object missing a mandatory attribute, yields NULL.
 *
 * A missing CKA_ID, CKA_LABEL, CKA_ALWAYS_AUTHENTICATE or CKA_VALUE is only
 * logged.  The returned object starts with a reference count of 1 and holds
 * a reference to the slot.
 */
PKCS11_OBJECT_private *pkcs11_object_from_handle(PKCS11_SLOT_private *slot,
		CK_SESSION_HANDLE session, CK_OBJECT_HANDLE object)
{
	PKCS11_CTX_private *ctx = slot->ctx;
	PKCS11_OBJECT_private *result;
	PKCS11_OBJECT_ops *key_ops = NULL;
	CK_OBJECT_CLASS object_class = -1;
	CK_KEY_TYPE key_type = -1;
	CK_CERTIFICATE_TYPE cert_type = -1;
	size_t attr_len;
	unsigned char *attr;

	if (pkcs11_getattr_val(ctx, session, object, CKA_CLASS,
			(CK_BYTE *)&object_class, sizeof(object_class))) {
		pkcs11_log(ctx, LOG_DEBUG, "Missing CKA_CLASS attribute\n");
		return NULL;
	}

	/* Step 1: validate the object and select the key operations */
	if (object_class == CKO_PUBLIC_KEY || object_class == CKO_PRIVATE_KEY) {
		if (pkcs11_getattr_val(ctx, session, object, CKA_KEY_TYPE,
				(CK_BYTE *)&key_type, sizeof(key_type))) {
			pkcs11_log(ctx, LOG_DEBUG, "Missing CKA_KEY_TYPE attribute\n");
			return NULL;
		}
		switch (key_type) {
		case CKK_RSA:
			key_ops = &pkcs11_rsa_ops;
			break;
#ifndef OPENSSL_NO_EC
		case CKK_EC:
			key_ops = &pkcs11_ec_ops;
			break;
# if OPENSSL_VERSION_NUMBER >= 0x30000000L
		case CKK_EC_EDWARDS: {
			int is_ed25519, is_ed448;

			/* CKA_EC_PARAMS tells Ed25519 and Ed448 apart */
			if (pkcs11_getattr_alloc(ctx, session, object,
					CKA_EC_PARAMS, &attr, &attr_len)) {
				pkcs11_log(ctx, LOG_DEBUG, "Missing CKA_EC_PARAMS attribute\n");
				return NULL;
			}
			/* Both the OID and the PrintableString encodings are accepted */
			is_ed25519 =
				(attr_len == sizeof(OID_ED25519) &&
					memcmp(attr, OID_ED25519, sizeof(OID_ED25519)) == 0) ||
				(attr_len == sizeof(STR_ED25519) &&
					memcmp(attr, STR_ED25519, sizeof(STR_ED25519)) == 0);
			is_ed448 = !is_ed25519 && (
				(attr_len == sizeof(OID_ED448) &&
					memcmp(attr, OID_ED448, sizeof(OID_ED448)) == 0) ||
				(attr_len == sizeof(STR_ED448) &&
					memcmp(attr, STR_ED448, sizeof(STR_ED448)) == 0));
			if (!is_ed25519 && !is_ed448)
				pkcs11_log(ctx, LOG_DEBUG, "Unsupported EdDSA OID\n");
			OPENSSL_free(attr);

			if (is_ed25519)
				key_ops = &pkcs11_ed25519_ops;
			else if (is_ed448)
				key_ops = &pkcs11_ed448_ops;
			else
				return NULL;
			break;
		}
# endif /* OPENSSL_VERSION_NUMBER >= 0x30000000L */
#endif /* OPENSSL_NO_EC */
		default:
			/* Keys of unknown types are skipped */
			pkcs11_log(ctx, LOG_DEBUG,
				"Unsupported CKA_KEY_TYPE attribute value: %lu\n",
				(unsigned long)key_type);
			return NULL;
		}
	} else if (object_class == CKO_CERTIFICATE) {
		if (pkcs11_getattr_val(ctx, session, object, CKA_CERTIFICATE_TYPE,
				(CK_BYTE *)&cert_type, sizeof(cert_type))) {
			pkcs11_log(ctx, LOG_DEBUG, "Missing CKA_CERTIFICATE_TYPE attribute\n");
			return NULL;
		}
		/* Only X.509 certificates are supported */
		if (cert_type != CKC_X_509) {
			pkcs11_log(ctx, LOG_DEBUG,
				"Unsupported CKA_CERTIFICATE_TYPE attribute value: %lu\n",
				(unsigned long)cert_type);
			return NULL;
		}
	} else {
		return NULL;
	}

	/* Step 2: allocate the object and fill in the common fields */
	result = OPENSSL_malloc(sizeof(*result));
	if (!result)
		return NULL;
	memset(result, 0, sizeof(*result));

	result->refcnt = 1;
	pthread_mutex_init(&result->lock, 0);
	result->object_class = object_class;
	result->object = object;
	result->slot = pkcs11_slot_ref(slot);

	result->id_len = sizeof(result->id);
	if (pkcs11_getattr_var(ctx, session, object, CKA_ID,
			result->id, &result->id_len)) {
		pkcs11_log(ctx, LOG_DEBUG, "Missing CKA_ID attribute\n");
		result->id_len = 0;
	}
	if (pkcs11_getattr_alloc(ctx, session, object, CKA_LABEL,
			(CK_BYTE **)&result->label, NULL))
		pkcs11_log(ctx, LOG_DEBUG, "Missing CKA_LABEL attribute\n");

	result->ops = key_ops;
	result->forkid = get_forkid();

	/* Step 3: class-specific attributes */
	if (object_class == CKO_PRIVATE_KEY) {
		if (pkcs11_getattr_val(ctx, session, object, CKA_ALWAYS_AUTHENTICATE,
				&result->always_authenticate, sizeof(CK_BBOOL))) {
			pkcs11_log(ctx, LOG_DEBUG, "Missing CKA_ALWAYS_AUTHENTICATE attribute\n");
		}
	} else if (object_class == CKO_CERTIFICATE) {
		if (pkcs11_getattr_alloc(ctx, session, object, CKA_VALUE,
				&attr, &attr_len)) {
			pkcs11_log(ctx, LOG_DEBUG, "Missing CKA_VALUE attribute\n");
		} else {
			/* d2i_X509() advances the pointer, so parse through a copy */
			const unsigned char *cursor = attr;

			result->x509 = d2i_X509(NULL, &cursor, (long)attr_len);
			OPENSSL_free(attr);
		}
	}
	return result;
}

/*
 * Find the object matching a template and wrap it in a
 * PKCS11_OBJECT_private.
 *
 * When "session" is CK_INVALID_HANDLE a session is borrowed from the slot
 * for the duration of the lookup and returned afterwards; otherwise the
 * caller's session is used and left alone.
 * Returns NULL when no session or no matching object could be obtained.
 */
PKCS11_OBJECT_private *pkcs11_object_from_template(PKCS11_SLOT_private *slot,
	CK_SESSION_HANDLE session, PKCS11_TEMPLATE *tmpl)
{
	const int borrowed = (session == CK_INVALID_HANDLE);
	PKCS11_OBJECT_private *found;
	CK_OBJECT_HANDLE handle;

	if (borrowed && pkcs11_get_session(slot, 0, &session))
		return NULL;

	handle = pkcs11_handle_from_template(slot, session, tmpl);
	found = handle ? pkcs11_object_from_handle(slot, session, handle) : NULL;

	if (borrowed)
		pkcs11_put_session(slot, session);

	return found;
}

/*
 * Find the object of class "object_class" that shares its CKA_ID with
 * "obj" on the same slot, e.g. the public key for a private key.
 * The session is passed through to pkcs11_object_from_template().
 */
PKCS11_OBJECT_private *pkcs11_object_from_object(PKCS11_OBJECT_private *obj,
	CK_SESSION_HANDLE session, CK_OBJECT_CLASS object_class)
{
	PKCS11_TEMPLATE search = {0};

	/* Search criteria: requested class + the ID of the source object */
	pkcs11_addattr_var(&search, CKA_CLASS, object_class);
	pkcs11_addattr(&search, CKA_ID, obj->id, obj->id_len);

	return pkcs11_object_from_template(obj->slot, session, &search);
}

/*
 * Drop one reference to an object; NULL is accepted and ignored.
 * Once the reference count reaches zero the slot reference, the cached
 * certificate, the label, the lock and the object itself are released.
 */
void pkcs11_object_free(PKCS11_OBJECT_private *obj)
{
	if (obj == NULL)
		return;

	if (pkcs11_atomic_add(&obj->refcnt, -1, &obj->lock) == 0) {
		/* That was the last reference */
		pkcs11_slot_unref(obj->slot);
		X509_free(obj->x509);
		OPENSSL_free(obj->label);
		pthread_mutex_destroy(&obj->lock);
		OPENSSL_free(obj);
	}
}

/*
 * Register the UI method (and its user data) that pkcs11_authenticate()
 * uses to ask for CKU_CONTEXT_SPECIFIC PINs interactively.
 * Returns 0 on success, -1 when ctx is NULL.
 */
int pkcs11_set_ui_method(PKCS11_CTX_private *ctx,
		UI_METHOD *ui_method, void *ui_user_data)
{
	if (ctx != NULL) {
		ctx->ui_method = ui_method;
		ctx->ui_user_data = ui_user_data;
		return 0;
	}
	return -1;
}

/*
 * Find the private key that belongs to a certificate
 *
 * The private keys of the certificate's slot are enumerated with the
 * certificate ID as the search criterion, and the first key whose ID is
 * byte-for-byte equal to the certificate ID is returned.
 * Returns NULL when the enumeration fails or no key matches.
 * The returned pointer refers to an entry of the array handed out by
 * pkcs11_enumerate_keys().
 */
PKCS11_KEY *pkcs11_find_key(PKCS11_OBJECT_private *cert)
{
	PKCS11_KEY filter = {0};
	PKCS11_KEY *list, *cur;
	unsigned int remaining;

	filter.isPrivate = 1;
	filter.id = cert->id;
	filter.id_len = cert->id_len;

	if (pkcs11_enumerate_keys(cert->slot, CKO_PRIVATE_KEY, &filter,
			&list, &remaining))
		return NULL;

	for (cur = list; remaining > 0; ++cur, --remaining) {
		PKCS11_OBJECT_private *kpriv = PRIVKEY(cur);

		if (!kpriv)
			continue;
		if (kpriv->id_len != cert->id_len)
			continue;
		if (memcmp(cert->id, kpriv->id, cert->id_len) == 0)
			return cur;
	}
	return NULL;
}

/*
 * Reopen the object by refreshing its object handle
 *
 * The object is searched for again by class, plus ID and label when they
 * are known, and obj->object is overwritten with the result.
 * Returns -1 when no session could be obtained; a lookup that ends with
 * CK_INVALID_HANDLE is reported through CRYPTOKI_checkerr with
 * CKR_OBJECT_HANDLE_INVALID; otherwise 0 is returned.
 */
int pkcs11_reload_object(PKCS11_OBJECT_private *obj)
{
	PKCS11_SLOT_private *slot = obj->slot;
	PKCS11_TEMPLATE search = {0};
	CK_SESSION_HANDLE session;

	if (pkcs11_get_session(slot, 0, &session))
		return -1;

	/* Use every identifying attribute we have cached */
	pkcs11_addattr_var(&search, CKA_CLASS, obj->object_class);
	if (obj->id_len != 0)
		pkcs11_addattr(&search, CKA_ID, obj->id, obj->id_len);
	if (obj->label != NULL)
		pkcs11_addattr_s(&search, CKA_LABEL, obj->label);

	obj->object = pkcs11_handle_from_template(slot, session, &search);
	pkcs11_put_session(slot, session);

	if (obj->object == CK_INVALID_HANDLE) {
		CRYPTOKI_checkerr(CKR_F_PKCS11_RELOAD_KEY, CKR_OBJECT_HANDLE_INVALID);
	}

	return 0;
}

/**
 * Generate an RSA key pair directly on the token
 *
 * @param slot    slot of the token to generate the key pair on
 * @param bits    modulus size in bits (sent as CKA_MODULUS_BITS)
 * @param label   optional CKA_LABEL of both keys, may be NULL
 * @param id      CKA_ID of both keys
 * @param id_len  length of id in bytes
 * @param params  provides the CKA_SENSITIVE and CKA_EXTRACTABLE values
 *                of the private key
 * @return 0 on success, -1 when no R/W session could be set up; a
 *         C_GenerateKeyPair failure is reported through CRYPTOKI_checkerr
 *
 * The public exponent is fixed to 65537 (bytes 01 00 01).
 */
int pkcs11_rsa_keygen(PKCS11_SLOT_private *slot, unsigned int bits,
		const char *label, const unsigned char *id, size_t id_len,
		const PKCS11_params *params)
{
	PKCS11_CTX_private *ctx = slot->ctx;
	PKCS11_TEMPLATE pub = {0}, priv = {0};
	CK_MECHANISM keygen_mech = {
		CKM_RSA_PKCS_KEY_PAIR_GEN, NULL_PTR, 0
	};
	CK_BYTE exponent[] = { 1, 0, 1 };
	CK_ULONG modulus_bits = bits;
	CK_OBJECT_HANDLE pub_handle, priv_handle;
	CK_SESSION_HANDLE session;
	int rv;

	if (pkcs11_init_keygen(slot, &session))
		return -1;

	/* Public key template: common attributes + RSA specifics */
	pkcs11_common_pubkey_attr(&pub, label, id, id_len);
	pkcs11_addattr_bool(&pub, CKA_ENCRYPT, TRUE);
	pkcs11_addattr_var(&pub, CKA_MODULUS_BITS, modulus_bits);
	pkcs11_addattr(&pub, CKA_PUBLIC_EXPONENT, exponent, sizeof exponent);

	/* Private key template: common attributes + decryption */
	pkcs11_common_privkey_attr(&priv, label, id, id_len, params);
	pkcs11_addattr_bool(&priv, CKA_DECRYPT, TRUE);

	/* Let the PKCS#11 module create the key pair */
	rv = CRYPTOKI_call(ctx, C_GenerateKeyPair(
		session, &keygen_mech,
		pub.attrs, pub.nattr,
		priv.attrs, priv.nattr,
		&pub_handle, &priv_handle));
	pkcs11_put_session(slot, session);

	/* Release whatever the templates allocated */
	pkcs11_zap_attrs(&priv);
	pkcs11_zap_attrs(&pub);

	CRYPTOKI_checkerr(CKR_F_PKCS11_GENERATE_KEY, rv);

	return 0;
}

#ifndef OPENSSL_NO_EC

/**
 * Generate an EC key pair directly on the token
 *
 * @param slot    slot of the token to generate the key pair on
 * @param curve   curve name; looked up as a NIST name first, then as an
 *                OpenSSL short name, then as an OpenSSL long name
 * @param label   optional CKA_LABEL of both keys, may be NULL
 * @param id      CKA_ID of both keys
 * @param id_len  length of id in bytes
 * @param params  provides the CKA_SENSITIVE and CKA_EXTRACTABLE values
 *                of the private key
 * @return 0 on success, -1 on a NULL or unknown curve, on an encoding or
 *         allocation failure, or when no R/W session could be set up; a
 *         C_GenerateKeyPair failure is reported through CRYPTOKI_checkerr
 *
 * Every failure after pkcs11_init_keygen() goes through the "err" label so
 * that the session is handed back to the slot and the DER buffer is freed.
 */
int pkcs11_ec_keygen(PKCS11_SLOT_private *slot, const char *curve,
		const char *label, const unsigned char *id, size_t id_len,
		const PKCS11_params *params)
{
	PKCS11_CTX_private *ctx = slot->ctx;
	CK_SESSION_HANDLE session;
	PKCS11_TEMPLATE pubtmpl = {0}, privtmpl = {0};
	CK_MECHANISM mechanism = {
		CKM_EC_KEY_PAIR_GEN, NULL_PTR, 0
	};
	CK_OBJECT_HANDLE pub_key_obj, priv_key_obj;
	int rv;
	unsigned char *ec_params = NULL;
	int ec_params_len = 0;
	unsigned char *tmp = NULL;
	ASN1_OBJECT *curve_obj = NULL;
	int curve_nid = NID_undef;

	/* The name lookups below do not accept a NULL string */
	if (!curve)
		return -1;

	if (pkcs11_init_keygen(slot, &session))
		return -1;

	curve_nid = EC_curve_nist2nid(curve);
	if (curve_nid == NID_undef)
		curve_nid = OBJ_sn2nid(curve);
	if (curve_nid == NID_undef)
		curve_nid = OBJ_ln2nid(curve);
	if (curve_nid == NID_undef)
		goto err;
	curve_obj = OBJ_nid2obj(curve_nid);
	if (!curve_obj)
		goto err;
	/* convert to DER format and take just the length */
	ec_params_len = i2d_ASN1_OBJECT(curve_obj, NULL);
	if (ec_params_len <= 0)
		goto err;
	ec_params = OPENSSL_malloc(ec_params_len);
	if (!ec_params)
		goto err;
	/**
	 * ec_params points to beginning of DER encoded object. Since we need this
	 * location later and OpenSSL changes it in i2d_ASN1_OBJECT to point to 1 byte
	 * after DER encoded object, we assign the pointer to temporary throw-away
	 * pointer tmp
	 */
	tmp = ec_params;
	if (i2d_ASN1_OBJECT(curve_obj, &tmp) <= 0)
		goto err;

	/* The following attributes are necessary for ECDSA and ECDH mechanisms */
	/* pubkey attributes */
	pkcs11_common_pubkey_attr(&pubtmpl, label, id, id_len);
	pkcs11_addattr(&pubtmpl, CKA_EC_PARAMS, ec_params, ec_params_len);

	/* privkey attributes */
	pkcs11_common_privkey_attr(&privtmpl, label, id, id_len, params);
	pkcs11_addattr_bool(&privtmpl, CKA_DERIVE, TRUE);

	/* call the pkcs11 module to create the key pair */
	rv = CRYPTOKI_call(ctx, C_GenerateKeyPair(
			session, &mechanism,
			pubtmpl.attrs, pubtmpl.nattr,
			privtmpl.attrs, privtmpl.nattr,
			&pub_key_obj, &priv_key_obj));
	pkcs11_put_session(slot, session);

	/* zap all memory allocated when building the template */
	pkcs11_zap_attrs(&privtmpl);
	pkcs11_zap_attrs(&pubtmpl);
	memset(ec_params, 0, ec_params_len);
	OPENSSL_free(ec_params);

	CRYPTOKI_checkerr(CKR_F_PKCS11_GENERATE_KEY, rv);
	return 0;

err:
	/* No template has been built yet: only the DER buffer (possibly NULL)
	 * and the session need to be released */
	OPENSSL_free(ec_params);
	pkcs11_put_session(slot, session);
	return -1;
}

# if OPENSSL_VERSION_NUMBER >= 0x30000000L
/**
 * Generate an EdDSA (Ed25519 / Ed448) key pair directly on the token
 *
 * @param slot    slot of the token to generate the key pair on
 * @param nid     NID_ED25519 or NID_ED448
 * @param label   optional CKA_LABEL of both keys, may be NULL
 * @param id      CKA_ID of both keys
 * @param id_len  length of id in bytes
 * @param params  provides the CKA_SENSITIVE and CKA_EXTRACTABLE values
 *                of the private key
 * @return 0 on success, -1 for an unsupported nid or when no R/W session
 *         could be set up; a C_GenerateKeyPair failure is reported through
 *         CRYPTOKI_checkerr
 *
 * The curve is sent to the token as a DER-encoded OID in CKA_EC_PARAMS.
 */
int pkcs11_eddsa_keygen(PKCS11_SLOT_private *slot,
	int nid, const char *label, const unsigned char *id,
	size_t id_len, const PKCS11_params *params)
{
	PKCS11_CTX_private *ctx = slot->ctx;
	CK_SESSION_HANDLE session;
	PKCS11_TEMPLATE pubtmpl = {0}, privtmpl = {0};
	CK_MECHANISM mechanism = {
		CKM_EC_EDWARDS_KEY_PAIR_GEN, NULL_PTR, 0
	};
	CK_OBJECT_HANDLE pub_key_obj, priv_key_obj;
	int rv;
	unsigned char *eddsa_params = NULL;
	size_t eddsa_params_len = 0;

	if (pkcs11_init_keygen(slot, &session))
		return -1;

	if (nid == NID_ED25519) {
		eddsa_params = (unsigned char *)OID_ED25519;
		eddsa_params_len = sizeof(OID_ED25519);
	} else if (nid == NID_ED448) {
		eddsa_params = (unsigned char *)OID_ED448;
		eddsa_params_len = sizeof(OID_ED448);
	} else {
		/* unsupported: hand the session back instead of leaking it */
		pkcs11_put_session(slot, session);
		return -1;
	}

	/* public key attributes */
	pkcs11_common_pubkey_attr(&pubtmpl, label, id, id_len);
	pkcs11_addattr(&pubtmpl, CKA_EC_PARAMS, eddsa_params, eddsa_params_len);
	pkcs11_addattr_bool(&pubtmpl, CKA_VERIFY, TRUE);

	/* private key attributes */
	pkcs11_common_privkey_attr(&privtmpl, label, id, id_len, params);
	pkcs11_addattr_bool(&privtmpl, CKA_SIGN, TRUE);

	/* generate key pair */
	rv = CRYPTOKI_call(ctx, C_GenerateKeyPair(
		session, &mechanism,
		pubtmpl.attrs, pubtmpl.nattr,
		privtmpl.attrs, privtmpl.nattr,
		&pub_key_obj, &priv_key_obj));
	pkcs11_put_session(slot, session);

	/* cleanup */
	pkcs11_zap_attrs(&privtmpl);
	pkcs11_zap_attrs(&pubtmpl);

	CRYPTOKI_checkerr(CKR_F_PKCS11_GENERATE_KEY, rv);
	return 0;
}
# endif /* OPENSSL_VERSION_NUMBER >= 0x30000000L */
#endif /* OPENSSL_NO_EC */

/*
 * Store a private key on the token
 * Thin wrapper around pkcs11_store_key() with CKO_PRIVATE_KEY.
 * Returns 0 on success, -1 on failure.
 */
int pkcs11_store_private_key(PKCS11_SLOT_private *slot, EVP_PKEY *pk,
		char *label, unsigned char *id, size_t id_len)
{
	return pkcs11_store_key(slot, pk, CKO_PRIVATE_KEY,
		label, id, id_len, NULL) ? -1 : 0;
}

/*
 * Store a public key on the token
 * Thin wrapper around pkcs11_store_key() with CKO_PUBLIC_KEY.
 * Returns 0 on success, -1 on failure.
 */
int pkcs11_store_public_key(PKCS11_SLOT_private *slot, EVP_PKEY *pk,
		char *label, unsigned char *id, size_t id_len)
{
	return pkcs11_store_key(slot, pk, CKO_PUBLIC_KEY,
		label, id, id_len, NULL) ? -1 : 0;
}

/*
 * Store an RSA key on the token as an object of class "type"
 * (CKO_PRIVATE_KEY or CKO_PUBLIC_KEY)
 *
 * Only RSA keys are supported: any other key type fails with
 * P11_R_NOT_SUPPORTED.  label and id are optional.  On success the created
 * object is registered in the slot's key list by pkcs11_init_key(), and
 * *ret_key is set to its entry when ret_key is not NULL.
 *
 * Returns the result of pkcs11_init_key() when the object was created,
 * -1 when the template or the session could not be prepared; a
 * C_CreateObject failure is reported through CRYPTOKI_checkerr.
 */
static int pkcs11_store_key(PKCS11_SLOT_private *slot, EVP_PKEY *pk,
		CK_OBJECT_CLASS type, char *label, unsigned char *id, size_t id_len,
		PKCS11_KEY **ret_key)
{
	PKCS11_CTX_private *ctx = slot->ctx;
	PKCS11_TEMPLATE tmpl = {0};
	CK_SESSION_HANDLE session;
	CK_OBJECT_HANDLE object;
	CK_KEY_TYPE key_type_rsa = CKK_RSA;
	int rv, r = -1;
	const BIGNUM *rsa_n, *rsa_e, *rsa_d, *rsa_p, *rsa_q, *rsa_dmp1, *rsa_dmq1, *rsa_iqmp;

	/* Now build the key attrs */
	pkcs11_addattr_var(&tmpl, CKA_CLASS, type);
	if (label)
		pkcs11_addattr_s(&tmpl, CKA_LABEL, label);
	if (id && id_len)
		pkcs11_addattr(&tmpl, CKA_ID, id, id_len);
	pkcs11_addattr_bool(&tmpl, CKA_TOKEN, TRUE);
	if (type == CKO_PRIVATE_KEY) {
		pkcs11_addattr_bool(&tmpl, CKA_PRIVATE, TRUE);
		pkcs11_addattr_bool(&tmpl, CKA_SENSITIVE, TRUE);
		pkcs11_addattr_bool(&tmpl, CKA_DECRYPT, TRUE);
		pkcs11_addattr_bool(&tmpl, CKA_SIGN, TRUE);
		pkcs11_addattr_bool(&tmpl, CKA_UNWRAP, TRUE);
	} else { /* CKO_PUBLIC_KEY */
		pkcs11_addattr_bool(&tmpl, CKA_ENCRYPT, TRUE);
		pkcs11_addattr_bool(&tmpl, CKA_VERIFY, TRUE);
		pkcs11_addattr_bool(&tmpl, CKA_WRAP, TRUE);
	}
#if OPENSSL_VERSION_NUMBER >= 0x10100003L || ( defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER >= 0x3050000fL )
	if (EVP_PKEY_base_id(pk) == EVP_PKEY_RSA) {
		RSA *rsa = EVP_PKEY_get1_RSA(pk);

		/* EVP_PKEY_get1_RSA() may fail even for an RSA key;
		 * do not hand a NULL RSA to the accessors below */
		if (!rsa) {
			pkcs11_zap_attrs(&tmpl);
			return -1;
		}
		RSA_get0_key(rsa, &rsa_n, &rsa_e, &rsa_d);
		RSA_get0_factors(rsa, &rsa_p, &rsa_q);
		RSA_get0_crt_params(rsa, &rsa_dmp1, &rsa_dmq1, &rsa_iqmp);
		/* Only our extra reference is dropped here.
		 * NOTE(review): the BIGNUMs are used below on the assumption that
		 * pk keeps its own reference to the RSA key - confirm */
		RSA_free(rsa);
#else
	if (pk->type == EVP_PKEY_RSA) {
		RSA *rsa = pk->pkey.rsa;
		rsa_n = rsa->n;
		rsa_e = rsa->e;
		rsa_d = rsa->d;
		rsa_p = rsa->p;
		rsa_q = rsa->q;
		rsa_dmp1 = rsa->dmp1;
		rsa_dmq1 = rsa->dmq1;
		rsa_iqmp = rsa->iqmp;
#endif
		pkcs11_addattr_var(&tmpl, CKA_KEY_TYPE, key_type_rsa);
		pkcs11_addattr_bn(&tmpl, CKA_MODULUS, rsa_n);
		pkcs11_addattr_bn(&tmpl, CKA_PUBLIC_EXPONENT, rsa_e);
		if (type == CKO_PRIVATE_KEY) {
			pkcs11_addattr_bn(&tmpl, CKA_PRIVATE_EXPONENT, rsa_d);
			pkcs11_addattr_bn(&tmpl, CKA_PRIME_1, rsa_p);
			pkcs11_addattr_bn(&tmpl, CKA_PRIME_2, rsa_q);
			/* The CRT parameters are optional */
			if (rsa_dmp1)
				pkcs11_addattr_bn(&tmpl, CKA_EXPONENT_1, rsa_dmp1);
			if (rsa_dmq1)
				pkcs11_addattr_bn(&tmpl, CKA_EXPONENT_2, rsa_dmq1);
			if (rsa_iqmp)
				pkcs11_addattr_bn(&tmpl, CKA_COEFFICIENT, rsa_iqmp);
		}
	} else {
		pkcs11_zap_attrs(&tmpl);
		P11err(P11_F_PKCS11_STORE_KEY, P11_R_NOT_SUPPORTED);
		return -1;
	}

	/* Creating a token object needs an R/W session */
	if (pkcs11_get_session(slot, 1, &session)) {
		pkcs11_zap_attrs(&tmpl);
		return -1;
	}

	/* Now call the pkcs11 module to create the object */
	rv = CRYPTOKI_call(ctx, C_CreateObject(session, tmpl.attrs, tmpl.nattr, &object));

	/* Zap all memory allocated when building the template */
	pkcs11_zap_attrs(&tmpl);

	if (rv == CKR_OK) {
		/* Gobble the key object */
		r = pkcs11_init_key(slot, session, object, type, ret_key);
	}
	pkcs11_put_session(slot, session);

	CRYPTOKI_checkerr(CKR_F_PKCS11_STORE_KEY, rv);
	return r;

}

/*
 * Get the key type
 * Returns the pkey_type of the key's operations table, or EVP_PKEY_NONE
 * for an object without one.
 */
int pkcs11_get_key_type(PKCS11_OBJECT_private *key)
{
	return key->ops ? key->ops->pkey_type : EVP_PKEY_NONE;
}

/*
 * Create an EVP_PKEY OpenSSL object for a given key
 *
 * object_class selects which half of the key pair is wanted.  When key0 is
 * of a different class, the matching object with the same CKA_ID is looked
 * up on the token and used instead; that temporary object is released
 * before returning.
 *
 * Returns a new EVP_PKEY (for EdDSA: a new reference to the cached one)
 * owned by the caller, or NULL on failure or for an unsupported key type.
 */
EVP_PKEY *pkcs11_get_key(PKCS11_OBJECT_private *key0, CK_OBJECT_CLASS object_class)
{
	PKCS11_OBJECT_private *key = key0;
	EVP_PKEY *ret = NULL;
	RSA *rsa;
#if OPENSSL_VERSION_NUMBER < 0x30000000L || defined(LIBRESSL_VERSION_NUMBER)
	EC_KEY *ec_key;
#endif

	if (key->object_class != object_class)
		key = pkcs11_object_from_object(key, CK_INVALID_HANDLE, object_class);
	if (!key || !key->ops)
		goto err;
	/* Build the cached EVP_PKEY on first use */
	if (!key->evp_key) {
		key->evp_key = key->ops->get_evp_key(key);
		if (!key->evp_key)
			goto err;
	}
	/* We need a full copy of the EVP_PKEY as it will be modified later.
	 * Using a reference would mean changes to the duplicated EVP_PKEY could
	 * affect the original one.
	 */
	switch (EVP_PKEY_base_id(key->evp_key)) {
	case EVP_PKEY_RSA:
		/* Do not try to duplicate foreign RSA keys */
		rsa = EVP_PKEY_get1_RSA(key->evp_key);
		if (!rsa)
			goto err;
		ret = EVP_PKEY_new();
		if (!ret) {
			RSA_free(rsa);
			goto err;
		}
		if (!EVP_PKEY_assign_RSA(ret, rsa)) {
			RSA_free(rsa);
			EVP_PKEY_free(ret);
			ret = NULL; /* never return the freed EVP_PKEY */
			goto err;
		}
		if (key->object_class != CKO_PRIVATE_KEY)
			pkcs11_set_ex_data_rsa(rsa, NULL);
		break;
	case EVP_PKEY_EC:
#if OPENSSL_VERSION_NUMBER < 0x30000000L || defined(LIBRESSL_VERSION_NUMBER)
		ec_key = EVP_PKEY_get1_EC_KEY(key->evp_key);
		if (!ec_key)
			goto err;
		ret = EVP_PKEY_new();
		if (!ret) {
			EC_KEY_free(ec_key);
			goto err;
		}
		if (!EVP_PKEY_assign_EC_KEY(ret, ec_key)) {
			EC_KEY_free(ec_key);
			EVP_PKEY_free(ret);
			ret = NULL; /* never return the freed EVP_PKEY */
			goto err;
		}
		if (key->object_class != CKO_PRIVATE_KEY)
			pkcs11_set_ex_data_ec(ec_key, NULL);
#else
		/* pkcs11_ec_copy() method is only set for private keys,
		 * so public keys do not have a PKCS11_OBJECT reference */
		ret = EVP_PKEY_dup(key->evp_key);
#endif
		break;
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
	case EVP_PKEY_ED25519:
	case EVP_PKEY_ED448:
		/* EdDSA keys are shared by reference rather than copied */
		ret = key->evp_key;
		EVP_PKEY_up_ref(key->evp_key);
		break;
#endif /* OPENSSL_VERSION_NUMBER >= 0x30000000L */
	default:
		pkcs11_log(key0->slot->ctx, LOG_DEBUG, "Unsupported key type\n");
		break;
	}
err:
	/* Release the temporary object looked up for a different class */
	if (key != key0)
		pkcs11_object_free(key);
	return ret;
}

/*
 * Authenticate a private key operation if needed
 * This function *only* handles CKU_CONTEXT_SPECIFIC logins.
 *
 * With a protected authentication path (slot->secure_login) the token is
 * asked to log in without a PIN.  Otherwise the PIN is requested through
 * the UI method registered with pkcs11_set_ui_method().
 *
 * Returns 0 on success (CKR_USER_ALREADY_LOGGED_IN counts as success),
 * P11_R_UI_FAILED when the PIN could not be obtained, or the CK_RV of
 * C_Login otherwise.
 */
int pkcs11_authenticate(PKCS11_OBJECT_private *key, CK_SESSION_HANDLE session)
{
	PKCS11_SLOT_private *slot = key->slot;
	PKCS11_CTX_private *ctx = slot->ctx;
	char pin[MAX_PIN_LENGTH+1];
	char *prompt;
	UI *ui;
	int rv;

	/* Handle CKF_PROTECTED_AUTHENTICATION_PATH */
	if (slot->secure_login) {
		rv = CRYPTOKI_call(ctx,
			C_Login(session, CKU_CONTEXT_SPECIFIC, NULL, 0));
		return rv == CKR_USER_ALREADY_LOGGED_IN ? 0 : rv;
	}

	/* Call UI to ask for a PIN */
	ui = UI_new_method(ctx->ui_method);
	if (!ui)
		return P11_R_UI_FAILED;
	if (ctx->ui_user_data)
		UI_add_user_data(ui, ctx->ui_user_data);
	memset(pin, 0, MAX_PIN_LENGTH+1);
	prompt = UI_construct_prompt(ui, "PKCS#11 key PIN", key->label);
	if (!prompt) {
		UI_free(ui); /* do not leak the UI object */
		return P11_R_UI_FAILED;
	}
	/* The UI keeps its own copy of the prompt, so ours can be freed
	 * right after the call */
	if (UI_dup_input_string(ui, prompt,
			UI_INPUT_FLAG_DEFAULT_PWD, pin, 4, MAX_PIN_LENGTH) <= 0) {
		UI_free(ui);
		OPENSSL_free(prompt);
		return P11_R_UI_FAILED;
	}
	OPENSSL_free(prompt);

	if (UI_process(ui)) {
		UI_free(ui);
		/* The buffer may already hold (part of) the PIN */
		OPENSSL_cleanse(pin, MAX_PIN_LENGTH+1);
		return P11_R_UI_FAILED;
	}
	UI_free(ui);

	/* Login with the PIN */
	rv = CRYPTOKI_call(ctx,
		C_Login(session, CKU_CONTEXT_SPECIFIC,
			(CK_UTF8CHAR *)pin, (CK_ULONG)strlen(pin)));
	OPENSSL_cleanse(pin, MAX_PIN_LENGTH+1);
	return rv == CKR_USER_ALREADY_LOGGED_IN ? 0 : rv;
}

/*
 * Return keys of a given type (public or private) matching the key_template
 *
 * The token is searched by object class, narrowed down by the ID and the
 * label of key_template when they are set.  Keys found are added to the
 * slot's key list for that type, and the complete list is handed out
 * through keyp/countp (both optional).
 * Returns 0 on success; on failure the slot's key list for that type is
 * destroyed and -1 is returned.
 */
int pkcs11_enumerate_keys(PKCS11_SLOT_private *slot, unsigned int type, const PKCS11_KEY *key_template,
		PKCS11_KEY **keyp, unsigned int *countp)
{
	PKCS11_TEMPLATE search = {0};
	CK_OBJECT_CLASS object_class = type;
	CK_SESSION_HANDLE session;
	PKCS11_keys *list;
	int found;

	list = (type == CKO_PRIVATE_KEY) ? &slot->prv : &slot->pub;

	/* Search criteria */
	pkcs11_addattr_var(&search, CKA_CLASS, object_class);
	if (key_template && key_template->id_len)
		pkcs11_addattr(&search, CKA_ID, key_template->id, key_template->id_len);
	if (key_template && key_template->label)
		pkcs11_addattr_s(&search, CKA_LABEL, key_template->label);

	if (pkcs11_get_session(slot, 0, &session))
		return -1;
	found = pkcs11_find_keys(slot, session, type, &search);
	pkcs11_put_session(slot, session);

	if (found < 0) {
		pkcs11_destroy_keys(slot, type);
		return -1;
	}

	if (keyp)
		*keyp = list->keys;
	if (countp)
		*countp = list->num;
	return 0;
}

/**
 * Remove an object from the associated token
 *
 * The object is destroyed with C_DestroyObject in an R/W session.
 * Returns 0 on success, -1 when no session could be obtained; a
 * C_DestroyObject failure is reported through CRYPTOKI_checkerr.
 */
int pkcs11_remove_object(PKCS11_OBJECT_private *obj)
{
	PKCS11_SLOT_private *slot = obj->slot;
	PKCS11_CTX_private *ctx = slot->ctx;
	CK_SESSION_HANDLE rw_session;
	int status;

	if (pkcs11_get_session(slot, 1, &rw_session) != 0)
		return -1;

	status = CRYPTOKI_call(ctx, C_DestroyObject(rw_session, obj->object));
	pkcs11_put_session(slot, rw_session);

	CRYPTOKI_checkerr(CKR_F_PKCS11_REMOVE_KEY, status);
	return 0;
}

/*
 * Find all keys of a given type (public or private) matching template
 *
 * Starts an object search and feeds every handle it yields to
 * pkcs11_next_key() until that reports the end of the list or an error.
 * Returns 0 when the whole list was processed, -1 on error.
 */
static int pkcs11_find_keys(PKCS11_SLOT_private *slot, CK_SESSION_HANDLE session, unsigned int type, PKCS11_TEMPLATE *tmpl)
{
	PKCS11_CTX_private *ctx = slot->ctx;
	int rv, step;

	/* Ask the PKCS#11 module to enumerate all matching objects */
	rv = CRYPTOKI_call(ctx,
		C_FindObjectsInit(session, tmpl->attrs, tmpl->nattr));
	CRYPTOKI_checkerr(CKR_F_PKCS11_FIND_KEYS, rv);

	/* 0: key processed, keep going; 1: no more keys; <0: error */
	for (;;) {
		step = pkcs11_next_key(ctx, slot, session, type);
		if (step != 0)
			break;
	}

	CRYPTOKI_call(ctx, C_FindObjectsFinal(session));

	if (step < 0)
		return -1;
	return 0;
}

/*
 * Fetch the next handle of the running object search and register it as a
 * key of the given type.
 * Returns 0 when a key was processed, 1 when the search is exhausted and
 * -1 when the key could not be registered; a C_FindObjects failure is
 * reported through CRYPTOKI_checkerr.
 */
static int pkcs11_next_key(PKCS11_CTX_private *ctx, PKCS11_SLOT_private *slot,
		CK_SESSION_HANDLE session, CK_OBJECT_CLASS type)
{
	CK_OBJECT_HANDLE handle;
	CK_ULONG nfound;
	int rv;

	/* Ask for a single handle at a time */
	rv = CRYPTOKI_call(ctx, C_FindObjects(session, &handle, 1, &nfound));
	CRYPTOKI_checkerr(CKR_F_PKCS11_NEXT_KEY, rv);

	if (nfound == 0)
		return 1;

	return pkcs11_init_key(slot, session, handle, type, NULL) ? -1 : 0;
}

/*
 * Take an additional reference to an object.
 * Returns obj to allow chaining; the reference is dropped again with
 * pkcs11_object_free().
 */
PKCS11_OBJECT_private *pkcs11_object_ref(PKCS11_OBJECT_private *obj)
{
	(void)pkcs11_atomic_add(&obj->refcnt, 1, &obj->lock);

	return obj;
}

/*
 * Register a token object as a key in the slot's list of private or public
 * keys (selected by type).
 *
 * A handle that is already in the list is not added again.  When ret is
 * not NULL, *ret is set to the (existing or new) list entry.
 * Returns 0 on success, -1 when the object could not be loaded or the list
 * could not be extended.
 */
static int pkcs11_init_key(PKCS11_SLOT_private *slot, CK_SESSION_HANDLE session,
	CK_OBJECT_HANDLE object, CK_OBJECT_CLASS type, PKCS11_KEY **ret)
{
	PKCS11_keys *list = (type == CKO_PRIVATE_KEY) ? &slot->prv : &slot->pub;
	PKCS11_OBJECT_private *kpriv;
	PKCS11_KEY *entry, *grown;
	int idx;

	/* Reuse the entry of an already known PKCS#11 object handle */
	/* TODO: Rewrite the O(n) algorithm as O(log n),
	 * or it may be too slow with a large number of keys */
	for (idx = 0; idx < list->num; ++idx) {
		PKCS11_KEY *known = &list->keys[idx];

		if (PRIVKEY(known)->object != object)
			continue;
		if (ret)
			*ret = known;
		return 0;
	}

	kpriv = pkcs11_object_from_handle(slot, session, object);
	if (!kpriv)
		return -1;

	/* Grow the list by one entry; the old array stays valid on failure */
	grown = OPENSSL_realloc(list->keys, (list->num + 1) * sizeof(PKCS11_KEY));
	if (!grown) {
		pkcs11_object_free(kpriv);
		return -1;
	}
	list->keys = grown;
	entry = &list->keys[list->num];
	list->num++;
	memset(entry, 0, sizeof(PKCS11_KEY));

	/* The public fields alias the data owned by the private object */
	entry->_private = kpriv;
	entry->id = kpriv->id;
	entry->id_len = kpriv->id_len;
	entry->label = kpriv->label;
	entry->isPrivate = (type == CKO_PRIVATE_KEY);

	if (ret)
		*ret = entry;
	return 0;
}

/*
 * Prepare the slot for key generation and acquire an R/W session.
 *
 * When the slot is not in R/W mode yet, an R/W session is opened and the
 * user is logged in again with the previously used PIN.
 * Returns 0 on success with *session set, -1 on failure.
 *
 * slot->lock is held only while reading rw_mode and is released exactly
 * once; it is not held across pkcs11_open_session() and pkcs11_login().
 */
static int pkcs11_init_keygen(PKCS11_SLOT_private *slot, CK_SESSION_HANDLE *session)
{
	int rw_mode;

	pthread_mutex_lock(&slot->lock);
	rw_mode = slot->rw_mode;
	pthread_mutex_unlock(&slot->lock);

	/* R/W session is mandatory for key generation. */
	if (rw_mode != 1) {
		if (pkcs11_open_session(slot, 1))
			return -1;
		/* open_session will call C_CloseAllSessions which logs everyone out */
		if (pkcs11_login(slot, 0, slot->prev_pin))
			return -1;
	}
	return pkcs11_get_session(slot, 1, session);
}

/*
 * Add the attributes shared by all generated public keys to a template:
 * the ID, the optional label, and the TOKEN, VERIFY and WRAP flags set to
 * TRUE.
 */
static void pkcs11_common_pubkey_attr(PKCS11_TEMPLATE *pubtmpl,
		const char *label, const unsigned char *id, size_t id_len)
{
	/* Identification */
	pkcs11_addattr(pubtmpl, CKA_ID, (void *)id, id_len);
	if (label != NULL)
		pkcs11_addattr_s(pubtmpl, CKA_LABEL, label);

	/* Storage and usage flags */
	pkcs11_addattr_bool(pubtmpl, CKA_TOKEN, TRUE);
	pkcs11_addattr_bool(pubtmpl, CKA_VERIFY, TRUE);
	pkcs11_addattr_bool(pubtmpl, CKA_WRAP, TRUE);
}

/*
 * Add the attributes shared by all generated private keys to a template:
 * the ID, the optional label, the PRIVATE, TOKEN, SIGN and UNWRAP flags
 * set to TRUE, and SENSITIVE / EXTRACTABLE taken from params.
 * params is dereferenced unconditionally.
 */
static void pkcs11_common_privkey_attr(PKCS11_TEMPLATE *privtmpl,
		const char *label, const unsigned char *id, size_t id_len,
		const PKCS11_params *params)
{
	/* Identification */
	pkcs11_addattr(privtmpl, CKA_ID, (void *)id, id_len);
	if (label != NULL)
		pkcs11_addattr_s(privtmpl, CKA_LABEL, label);

	/* Storage and protection flags */
	pkcs11_addattr_bool(privtmpl, CKA_PRIVATE, TRUE);
	pkcs11_addattr_bool(privtmpl, CKA_TOKEN, TRUE);
	pkcs11_addattr_bool(privtmpl, CKA_SENSITIVE, params->sensitive);
	pkcs11_addattr_bool(privtmpl, CKA_EXTRACTABLE, params->extractable);

	/* Usage flags */
	pkcs11_addattr_bool(privtmpl, CKA_SIGN, TRUE);
	pkcs11_addattr_bool(privtmpl, CKA_UNWRAP, TRUE);
}

/*
 * Destroy all keys of a given type (public or private)
 *
 * The entries are released from the end of the list towards the start:
 * for each one the cached EVP_PKEY is freed and the reference to the
 * private object is dropped.  Finally the list itself is freed and reset.
 */
void pkcs11_destroy_keys(PKCS11_SLOT_private *slot, unsigned int type)
{
	PKCS11_keys *list = (type == CKO_PRIVATE_KEY) ? &slot->prv : &slot->pub;

	while (list->num > 0) {
		PKCS11_OBJECT_private *kpriv;

		/* Shrink the list before releasing the entry */
		list->num--;
		kpriv = PRIVKEY(&list->keys[list->num]);
		if (!kpriv)
			continue;
		EVP_PKEY_free(kpriv->evp_key);
		pkcs11_object_free(kpriv);
	}

	/* OPENSSL_free() accepts NULL */
	OPENSSL_free(list->keys);
	list->keys = NULL;
	list->num = 0;
}

#if OPENSSL_VERSION_NUMBER < 0x40000000L
# if OPENSSL_VERSION_NUMBER >= 0x30000000L
/* Attach a PKCS11_OBJECT_private to an EVP_PKEY through the ex_data slot
 * identified by pkey_ex_index; the result of the call is ignored */
void pkcs11_set_ex_data_pkey(EVP_PKEY *pkey, PKCS11_OBJECT_private *key)
{
	(void)EVP_PKEY_set_ex_data(pkey, pkey_ex_index, key);
}

/* Retrieve the PKCS11_OBJECT_private stored in the ex_data slot of an
 * EVP_PKEY by pkcs11_set_ex_data_pkey() */
PKCS11_OBJECT_private *pkcs11_get_ex_data_pkey(const EVP_PKEY *pkey)
{
	void *stored = EVP_PKEY_get_ex_data(pkey, pkey_ex_index);

	return stored;
}

/*
 * Reserve an EVP_PKEY ex_data index for libp11, unless one is already set.
 * If the allocation reports an error (negative value), pkey_ex_index is
 * left at 0.
 */
void alloc_pkey_ex_index(void)
{
	if (pkey_ex_index != 0)
		return; /* nothing to do: an index is already recorded */

	/* Workaround for OpenSSL RT3710: retry for as long as 0 is returned */
	do {
		pkey_ex_index = EVP_PKEY_get_ex_new_index(0, "libp11 EVP_PKEY",
			NULL, NULL, NULL);
	} while (pkey_ex_index == 0);

	if (pkey_ex_index < 0)
		pkey_ex_index = 0; /* Fallback to app_data */
}

/*
 * Give back the EVP_PKEY ex_data index held in pkey_ex_index, if it is
 * positive, and reset the stored value to 0.
 */
void free_pkey_ex_index(void)
{
	if (pkey_ex_index <= 0)
		return; /* no index is currently held */

	CRYPTO_free_ex_index(CRYPTO_EX_INDEX_EVP_PKEY, pkey_ex_index);
	pkey_ex_index = 0;
}
# endif /* OPENSSL_VERSION_NUMBER >= 0x30000000L */

/*
 * Map an OpenSSL message digest to the matching PKCS#11 digest mechanism.
 * Returns 0 when the digest has no entry in the table below.
 */
static CK_MECHANISM_TYPE pkcs11_md2ckm(const EVP_MD *md)
{
	static const struct {
		int nid;
		CK_MECHANISM_TYPE ckm;
	} digest_map[] = {
		{ NID_sha1, CKM_SHA_1 },
		{ NID_sha224, CKM_SHA224 },
		{ NID_sha256, CKM_SHA256 },
		{ NID_sha512, CKM_SHA512 },
		{ NID_sha384, CKM_SHA384 },
#if OPENSSL_VERSION_NUMBER >= 0x10101000L && !defined(LIBRESSL_VERSION_NUMBER)
		{ NID_sha3_224, CKM_SHA3_224 },
		{ NID_sha3_256, CKM_SHA3_256 },
		{ NID_sha3_384, CKM_SHA3_384 },
		{ NID_sha3_512, CKM_SHA3_512 },
#endif
	};
	const int nid = EVP_MD_type(md);
	size_t i;

	for (i = 0; i < sizeof digest_map / sizeof digest_map[0]; i++) {
		if (digest_map[i].nid == nid)
			return digest_map[i].ckm;
	}
	return 0; /* digest not in the table */
}

/*
 * Map an OpenSSL message digest to the matching PKCS#11 MGF1 mask
 * generation function identifier.
 * Returns 0 when the digest has no entry in the table below.
 */
static CK_RSA_PKCS_MGF_TYPE pkcs11_md2ckg(const EVP_MD *md)
{
	static const struct {
		int nid;
		CK_RSA_PKCS_MGF_TYPE ckg;
	} mgf_map[] = {
		{ NID_sha1, CKG_MGF1_SHA1 },
		{ NID_sha224, CKG_MGF1_SHA224 },
		{ NID_sha256, CKG_MGF1_SHA256 },
		{ NID_sha512, CKG_MGF1_SHA512 },
		{ NID_sha384, CKG_MGF1_SHA384 },
#if OPENSSL_VERSION_NUMBER >= 0x10101000L && !defined(LIBRESSL_VERSION_NUMBER)
		{ NID_sha3_224, CKG_MGF1_SHA3_224 },
		{ NID_sha3_256, CKG_MGF1_SHA3_256 },
		{ NID_sha3_384, CKG_MGF1_SHA3_384 },
		{ NID_sha3_512, CKG_MGF1_SHA3_512 },
#endif
	};
	const int nid = EVP_MD_type(md);
	size_t i;

	for (i = 0; i < sizeof mgf_map / sizeof mgf_map[0]; i++) {
		if (mgf_map[i].nid == nid)
			return mgf_map[i].ckg;
	}
	return 0; /* digest not in the table */
}

/*
 * Translate the RSA-PSS settings of an EVP_PKEY_CTX into a
 * CK_RSA_PKCS_PSS_PARAMS structure.
 *
 * pss   output structure; zeroed and filled once the parameters are read
 * ctx   OpenSSL key context to query (signature digest, MGF1 digest,
 *       salt length)
 * pctx  libp11 context, used only for debug logging
 *
 * A salt length of -1 is replaced with the signature digest size.
 * A salt length of -2 is replaced with a value computed from the key size,
 * the digest size and the key bit length. Any salt length that is still
 * negative after that is rejected.
 *
 * Returns 0 on success, or -1 if a parameter cannot be retrieved, the salt
 * length is invalid, or a digest has no PKCS#11 equivalent.
 */
static int pkcs11_params_pss(CK_RSA_PKCS_PSS_PARAMS *pss,
		EVP_PKEY_CTX *ctx, PKCS11_CTX_private *pctx)
{
	const EVP_MD *sig_md, *mgf1_md;
	EVP_PKEY *evp_pkey;
	int salt_len;

	/* retrieve PSS parameters */
	if (EVP_PKEY_CTX_get_signature_md(ctx, &sig_md) <= 0)
		return -1;
	if (EVP_PKEY_CTX_get_rsa_mgf1_md(ctx, &mgf1_md) <= 0)
		return -1;
	/* a negative return also means failure: salt_len would be unset */
	if (EVP_PKEY_CTX_get_rsa_pss_saltlen(ctx, &salt_len) <= 0)
		return -1;
	switch (salt_len) {
	case -1: /* salt length equals the digest length */
		salt_len = EVP_MD_size(sig_md);
		break;
	case -2: /* derive the salt length from the key and digest sizes */
		evp_pkey = EVP_PKEY_CTX_get0_pkey(ctx);
		if (!evp_pkey)
			return -1;
		salt_len = EVP_PKEY_size(evp_pkey) - EVP_MD_size(sig_md) - 2;
		if (((EVP_PKEY_bits(evp_pkey) - 1) & 0x7) == 0)
			salt_len--;
		break;
	default: /* explicit salt length requested by the caller */
		break;
	}
	/* Reject every negative value (underflow, failed size lookup or an
	 * unhandled special selector) before it is stored as CK_ULONG */
	if (salt_len < 0)
		return -1;
	pkcs11_log(pctx, LOG_DEBUG, "salt_len=%d sig_md=%s mdf1_md=%s\n",
		salt_len, EVP_MD_name(sig_md), EVP_MD_name(mgf1_md));

	/* fill the CK_RSA_PKCS_PSS_PARAMS structure */
	memset(pss, 0, sizeof(CK_RSA_PKCS_PSS_PARAMS));
	pss->hashAlg = pkcs11_md2ckm(sig_md);
	pss->mgf = pkcs11_md2ckg(mgf1_md);
	if (!pss->hashAlg || !pss->mgf)
		return -1;
	pss->sLen = salt_len;
	return 0;
}

/*
 * Translate the RSA-OAEP settings of an EVP_PKEY_CTX into a
 * CK_RSA_PKCS_OAEP_PARAMS structure.
 *
 * oaep  output structure; zeroed and filled once both digests are read
 * ctx   OpenSSL key context to query (OAEP digest and MGF1 digest)
 * pctx  libp11 context, used only for debug logging
 *
 * Returns 0 on success, or -1 if a digest cannot be retrieved or has no
 * PKCS#11 equivalent.
 */
static int pkcs11_params_oaep(CK_RSA_PKCS_OAEP_PARAMS *oaep,
		EVP_PKEY_CTX *ctx, PKCS11_CTX_private *pctx)
{
	const EVP_MD *hash_md;
	const EVP_MD *mask_md;

	/* query the digests configured on the context */
	if (EVP_PKEY_CTX_get_rsa_oaep_md(ctx, &hash_md) <= 0) {
		return -1;
	}
	if (EVP_PKEY_CTX_get_rsa_mgf1_md(ctx, &mask_md) <= 0) {
		return -1;
	}

	pkcs11_log(pctx, LOG_DEBUG, "oaep_md=%s mdf1_md=%s\n",
		EVP_MD_name(hash_md), EVP_MD_name(mask_md));

	/* start from a clean structure, then map both digests */
	memset(oaep, 0, sizeof(CK_RSA_PKCS_OAEP_PARAMS));
	oaep->hashAlg = pkcs11_md2ckm(hash_md);
	oaep->mgf = pkcs11_md2ckg(mask_md);
	if (oaep->hashAlg == 0 || oaep->mgf == 0) {
		return -1;
	}

	/* the OAEP "label" parameter is not supported yet: send it empty */
	oaep->source = CKZ_DATA_SPECIFIED;
	oaep->pSourceData = NULL;
	oaep->ulSourceDataLen = 0;
	return 0;
}

/*
 * Attempt to sign using the PKCS#11-backed RSA implementation.
 *
 * Only RSA_PKCS1_PSS_PADDING is handled here (as CKM_RSA_PKCS_PSS); for any
 * other padding, or whenever a prerequisite is missing, -1 is returned so
 * the caller can hand the request to the original RSA method.
 *
 * evp_pkey_ctx  OpenSSL key context describing the key and its parameters
 * sig           output buffer for the signature (must not be NULL)
 * siglen        in: capacity of sig, out: length of the signature
 * tbs, tbslen   digest to sign; its length must equal the size of the
 *               signature digest configured on the context
 *
 * Returns 1 on success (with *siglen updated), -1 otherwise.
 */
static int pkcs11_try_pkey_rsa_sign(EVP_PKEY_CTX *evp_pkey_ctx,
		unsigned char *sig, size_t *siglen,
		const unsigned char *tbs, size_t tbslen)
{
	EVP_PKEY *pkey;
	RSA *rsa;
	int rv = 0, padding;
	CK_ULONG size;
	PKCS11_OBJECT_private *key;
	PKCS11_SLOT_private *slot;
	PKCS11_CTX_private *ctx;
	const EVP_MD *sig_md;
	CK_SESSION_HANDLE session;
	CK_MECHANISM mechanism;
	CK_RSA_PKCS_PSS_PARAMS pss_params;

	/* RSA method has EVP_PKEY_FLAG_AUTOARGLEN set. OpenSSL core will handle
	 * the size inquiry internally. */
	if (!sig || !siglen)
		return -1;
	if (!evp_pkey_ctx)
		return -1;
	/* read the buffer capacity only after the arguments are validated */
	size = (CK_ULONG)*siglen;

	pkey = EVP_PKEY_CTX_get0_pkey(evp_pkey_ctx);
	if (!pkey)
		return -1;
	rsa = (RSA *)EVP_PKEY_get0_RSA(pkey);
	if (!rsa)
		return -1;

	key = pkcs11_get_ex_data_rsa(rsa);
	if (check_object_fork(key) < 0)
		return -1;

	slot = key->slot;
	ctx = slot->ctx;
	if (!ctx)
		return -1;
#ifdef DEBUG
	/* casts keep the arguments in line with the %p and %lu conversions */
	pkcs11_log(ctx, LOG_DEBUG, "%s:%d pkcs11_try_pkey_rsa_sign() "
		"sig=%p *siglen=%lu tbs=%p tbslen=%lu\n",
		__FILE__, __LINE__, (void *)sig, (unsigned long)*siglen,
		(const void *)tbs, (unsigned long)tbslen);
#endif
	if (EVP_PKEY_CTX_get_signature_md(evp_pkey_ctx, &sig_md) <= 0)
		return -1;
	if (tbslen != (size_t)EVP_MD_size(sig_md))
		return -1;

	memset(&mechanism, 0, sizeof mechanism);
	/* without this check a failed query would leave padding unset */
	if (EVP_PKEY_CTX_get_rsa_padding(evp_pkey_ctx, &padding) <= 0)
		return -1;
	switch (padding) {
	case RSA_PKCS1_PSS_PADDING:
		pkcs11_log(ctx, LOG_DEBUG, "padding=RSA_PKCS1_PSS_PADDING\n");
		if (pkcs11_params_pss(&pss_params, evp_pkey_ctx, ctx) < 0)
			return -1;
		mechanism.mechanism = CKM_RSA_PKCS_PSS;
		mechanism.pParameter = &pss_params;
		mechanism.ulParameterLen = sizeof pss_params;
		break;
	default:
		return -1;
	} /* end switch(padding) */

	if (pkcs11_get_session(slot, 0, &session))
		return -1;

	rv = CRYPTOKI_call(ctx,
		C_SignInit(session, &mechanism, key->object));
	if (rv != CKR_OK) {
		pkcs11_log(ctx, LOG_DEBUG, "%s:%d C_SignInit rv=%d\n",
			__FILE__, __LINE__, rv);
	} else if (key->always_authenticate == CK_TRUE)
		rv = pkcs11_authenticate(key, session);
	if (rv == CKR_OK) {
		rv = CRYPTOKI_call(ctx,
			C_Sign(session, (CK_BYTE_PTR)tbs, (CK_ULONG)tbslen, sig, &size));
		if (rv != CKR_OK) {
			pkcs11_log(ctx, LOG_DEBUG, "%s:%d C_Sign rv=%d\n",
				__FILE__, __LINE__, rv);
		}
	}
	/* the session is handed back on both the success and failure paths */
	pkcs11_put_session(slot, session);

	if (rv != CKR_OK)
		return -1;
	*siglen = size;
	return 1;
}

/*
 * Attempt to decrypt using the PKCS#11-backed RSA implementation.
 *
 * RSA_PKCS1_OAEP_PADDING is mapped to CKM_RSA_PKCS_OAEP and
 * RSA_PKCS1_PADDING to CKM_RSA_PKCS; for any other padding, or whenever a
 * prerequisite is missing, -1 is returned so the caller can hand the
 * request to the original RSA method.
 *
 * evp_pkey_ctx  OpenSSL key context describing the key and its parameters
 * out           output buffer for the plaintext (must not be NULL)
 * outlen        in: capacity of out, out: length of the plaintext
 * in, inlen     ciphertext to decrypt
 *
 * Returns 1 on success (with *outlen updated), -1 otherwise.
 */
static int pkcs11_try_pkey_rsa_decrypt(EVP_PKEY_CTX *evp_pkey_ctx,
		unsigned char *out, size_t *outlen,
		const unsigned char *in, size_t inlen)
{
	EVP_PKEY *pkey;
	RSA *rsa;
	int rv = 0, padding;
	CK_ULONG size;
	PKCS11_OBJECT_private *key;
	PKCS11_SLOT_private *slot;
	PKCS11_CTX_private *ctx;
	CK_SESSION_HANDLE session;
	CK_MECHANISM mechanism;
	CK_RSA_PKCS_OAEP_PARAMS oaep_params;

	/* RSA method has EVP_PKEY_FLAG_AUTOARGLEN set. OpenSSL core will handle
	 * the size inquiry internally. */
	if (!out || !outlen)
		return -1;
	if (!evp_pkey_ctx)
		return -1;
	/* read the buffer capacity only after the arguments are validated */
	size = (CK_ULONG)*outlen;

	pkey = EVP_PKEY_CTX_get0_pkey(evp_pkey_ctx);
	if (!pkey)
		return -1;
	rsa = (RSA *)EVP_PKEY_get0_RSA(pkey);
	if (!rsa)
		return -1;

	key = pkcs11_get_ex_data_rsa(rsa);
	if (check_object_fork(key) < 0)
		return -1;

	slot = key->slot;
	ctx = slot->ctx;
	if (!ctx)
		return -1;
#ifdef DEBUG
	/* casts keep the arguments in line with the %p and %lu conversions */
	pkcs11_log(ctx, LOG_DEBUG, "%s:%d pkcs11_try_pkey_rsa_decrypt() "
		"out=%p *outlen=%lu in=%p inlen=%lu\n",
		__FILE__, __LINE__, (void *)out, (unsigned long)*outlen,
		(const void *)in, (unsigned long)inlen);
#endif
	memset(&mechanism, 0, sizeof mechanism);
	/* without this check a failed query would leave padding unset */
	if (EVP_PKEY_CTX_get_rsa_padding(evp_pkey_ctx, &padding) <= 0)
		return -1;
	switch (padding) {
	case RSA_PKCS1_OAEP_PADDING:
		pkcs11_log(ctx, LOG_DEBUG, "padding=RSA_PKCS1_OAEP_PADDING\n");
		if (pkcs11_params_oaep(&oaep_params, evp_pkey_ctx, ctx) < 0)
			return -1;
		mechanism.mechanism = CKM_RSA_PKCS_OAEP;
		mechanism.pParameter = &oaep_params;
		mechanism.ulParameterLen = sizeof oaep_params;
		break;
	case RSA_PKCS1_PADDING:
		pkcs11_log(ctx, LOG_DEBUG, "padding=RSA_PKCS1_PADDING\n");
		mechanism.mechanism = CKM_RSA_PKCS;
		mechanism.pParameter = NULL;
		mechanism.ulParameterLen = 0;
		break;
	default:
		pkcs11_log(ctx, LOG_DEBUG, "%s:%d unsupported padding: %d\n",
			__FILE__, __LINE__, padding);
		return -1;
	} /* end switch(padding) */

	if (pkcs11_get_session(slot, 0, &session))
		return -1;

	rv = CRYPTOKI_call(ctx,
		C_DecryptInit(session, &mechanism, key->object));
	if (rv != CKR_OK) {
		pkcs11_log(ctx, LOG_DEBUG, "%s:%d C_DecryptInit rv=%d\n",
			__FILE__, __LINE__, rv);
	} else if (key->always_authenticate == CK_TRUE)
		rv = pkcs11_authenticate(key, session);
	if (rv == CKR_OK) {
		rv = CRYPTOKI_call(ctx,
			C_Decrypt(session, (CK_BYTE_PTR)in, (CK_ULONG)inlen, out, &size));
		if (rv != CKR_OK) {
			pkcs11_log(ctx, LOG_DEBUG, "%s:%d C_Decrypt rv=%d\n",
				__FILE__, __LINE__, rv);
		}
	}
	/* the session is handed back on both the success and failure paths */
	pkcs11_put_session(slot, session);

	if (rv != CKR_OK)
		return -1;
	*outlen = size;
	return 1;
}

/*
 * Sign callback installed in the libp11 RSA method: try the PKCS#11 path
 * first and, if it reports a negative result, delegate to the sign
 * callback saved from the original RSA method.
 */
static int pkcs11_pkey_rsa_sign(EVP_PKEY_CTX *evp_pkey_ctx,
		unsigned char *sig, size_t *siglen,
		const unsigned char *tbs, size_t tbslen)
{
	int result = pkcs11_try_pkey_rsa_sign(evp_pkey_ctx,
		sig, siglen, tbs, tbslen);

	if (result >= 0)
		return result;
	return (*orig_pkey_rsa_sign)(evp_pkey_ctx, sig, siglen, tbs, tbslen);
}

/*
 * Decrypt callback installed in the libp11 RSA method: try the PKCS#11
 * path first and, if it reports a negative result, delegate to the decrypt
 * callback saved from the original RSA method.
 */
static int pkcs11_pkey_rsa_decrypt(EVP_PKEY_CTX *evp_pkey_ctx,
		unsigned char *out, size_t *outlen,
		const unsigned char *in, size_t inlen)
{
	int result = pkcs11_try_pkey_rsa_decrypt(evp_pkey_ctx,
		out, outlen, in, inlen);

	if (result >= 0)
		return result;
	return (*orig_pkey_rsa_decrypt)(evp_pkey_ctx, out, outlen, in, inlen);
}

#if OPENSSL_VERSION_NUMBER < 0x100020d0L || defined(LIBRESSL_VERSION_NUMBER)
/*
 * Compatibility helper, compiled only for the library versions selected by
 * the surrounding #if: copy the sign_init and sign callbacks of pmeth into
 * the supplied output pointers. A NULL output pointer is skipped.
 */
void EVP_PKEY_meth_get_sign(EVP_PKEY_METHOD *pmeth,
		int (**psign_init) (EVP_PKEY_CTX *ctx),
		int (**psign) (EVP_PKEY_CTX *ctx,
			unsigned char *sig, size_t *siglen,
			const unsigned char *tbs, size_t tbslen))
{
	if (psign_init != NULL) {
		*psign_init = pmeth->sign_init;
	}
	if (psign != NULL) {
		*psign = pmeth->sign;
	}
}

/*
 * Compatibility helper, compiled only for the library versions selected by
 * the surrounding #if: copy the decrypt_init and decrypt callbacks of pmeth
 * into the supplied output pointers. A NULL output pointer is skipped.
 */
static void EVP_PKEY_meth_get_decrypt(EVP_PKEY_METHOD *pmeth,
		int (**pdecrypt_init) (EVP_PKEY_CTX *ctx),
		int (**pdecrypt) (EVP_PKEY_CTX *ctx,
			unsigned char *out,
			size_t *outlen,
			const unsigned char *in,
			size_t inlen))
{
	if (pdecrypt_init != NULL) {
		*pdecrypt_init = pmeth->decrypt_init;
	}
	if (pdecrypt != NULL) {
		*pdecrypt = pmeth->decrypt;
	}
}
#endif

/*
 * Build a new EVP_PKEY_RSA method that is a copy of the method found by
 * EVP_PKEY_meth_find(), with the sign and decrypt callbacks replaced by the
 * libp11 wrappers. The original sign/decrypt callbacks (and their init
 * callbacks) are stored in file-scope variables so the wrappers can
 * delegate to them.
 *
 * Returns the newly created method, or NULL if the original method or one
 * of its callbacks is unavailable or the new method cannot be created.
 */
EVP_PKEY_METHOD *pkcs11_pkey_method_rsa(void)
{
	EVP_PKEY_METHOD *meth;
	int pkey_id;

	/* Look the original EVP_PKEY_RSA method up only once and cache it */
	if (orig_method_rsa == NULL) {
#if OPENSSL_VERSION_NUMBER < 0x10101000L || defined(LIBRESSL_VERSION_NUMBER)
		orig_method_rsa = (EVP_PKEY_METHOD *)EVP_PKEY_meth_find(EVP_PKEY_RSA);
#else
		orig_method_rsa = EVP_PKEY_meth_find(EVP_PKEY_RSA);
#endif /* OPENSSL_VERSION_NUMBER < 0x10101000L || defined(LIBRESSL_VERSION_NUMBER) */
		if (orig_method_rsa == NULL) {
			return NULL;
		}
	}

	/* Make sure the method found really is the RSA one */
	EVP_PKEY_meth_get0_info(&pkey_id, NULL, orig_method_rsa);
	if (pkey_id != EVP_PKEY_RSA) {
		return NULL;
	}

	/* Remember the callbacks the wrappers fall back to */
	EVP_PKEY_meth_get_sign(orig_method_rsa,
		&orig_pkey_rsa_sign_init, &orig_pkey_rsa_sign);
	if (orig_pkey_rsa_sign == NULL) {
		return NULL;
	}
	EVP_PKEY_meth_get_decrypt(orig_method_rsa,
		&orig_pkey_rsa_decrypt_init, &orig_pkey_rsa_decrypt);
	if (orig_pkey_rsa_decrypt == NULL) {
		return NULL;
	}

	meth = EVP_PKEY_meth_new(EVP_PKEY_RSA, EVP_PKEY_FLAG_AUTOARGLEN);
	if (meth == NULL) {
		return NULL;
	}

	/* Start from a copy of the original, then override sign and decrypt */
	EVP_PKEY_meth_copy(meth, orig_method_rsa);
	EVP_PKEY_meth_set_sign(meth,
		orig_pkey_rsa_sign_init, pkcs11_pkey_rsa_sign);
	EVP_PKEY_meth_set_decrypt(meth,
		orig_pkey_rsa_decrypt_init, pkcs11_pkey_rsa_decrypt);

	return meth;
}
#endif /* OPENSSL_VERSION_NUMBER < 0x40000000L */

/* vim: set noexpandtab: */
