open Asn.S
open Asn_grammars

(* This type really conflates several things: the set of pk algos that describe
 * the public key, the set of hashes, the set of hash+pk algo combinations
 * that describe signatures, and the HMAC, symmetric cipher and password-based
 * encryption algorithms. They are conflated because they are all generated by
 * the same ASN grammar, AlgorithmIdentifier, to keep things close to the
 * standards.
 *
 * It's expected that downstream code will pick a subset and add a catch-all
 * that handles unsupported algos anyway.
 *)

(* The named elliptic curves recognised as parameters of an EC public key. *)
type ec_curve = [ `SECP256R1 | `SECP384R1 | `SECP521R1 ]

(* Human-readable name of a curve, spelled exactly like its variant tag. *)
let ec_curve_to_string curve =
  match curve with
  | `SECP521R1 -> "SECP521R1"
  | `SECP384R1 -> "SECP384R1"
  | `SECP256R1 -> "SECP256R1"

(* An algorithm as named by an AlgorithmIdentifier.  Constructors with
   arguments carry the decoded parameters of that identifier; the grammar
   mapping OIDs and parameters onto these constructors is [identifier]. *)
type t =

  (* pk algos *)
  (* any more? is the universe big enough? ramsey's theorem for pk cyphers? *)
  | RSA
  | EC_pub of ec_curve

  (* sig algos: a signature scheme combined with a digest *)
  | MD2_RSA
  | MD4_RSA
  | MD5_RSA
  | RIPEMD160_RSA
  | SHA1_RSA
  | SHA256_RSA
  | SHA384_RSA
  | SHA512_RSA
  | SHA224_RSA
  | ECDSA_SHA1
  | ECDSA_SHA224
  | ECDSA_SHA256
  | ECDSA_SHA384
  | ECDSA_SHA512

  (* Ed25519 has a single identifier that serves both as the public key
     algorithm and as the signature algorithm. *)
  | ED25519

  (* digest algorithms *)
  | MD2
  | MD4
  | MD5
  | SHA1
  | SHA256
  | SHA384
  | SHA512
  | SHA224
  | SHA512_224
  | SHA512_256

  (* HMAC algorithms *)
  | HMAC_SHA1
  | HMAC_SHA224
  | HMAC_SHA256
  | HMAC_SHA384
  | HMAC_SHA512

  (* symmetric block ciphers; the string is the CBC initialisation vector,
     carried as the OCTET STRING parameter *)
  | AES128_CBC of string
  | AES192_CBC of string
  | AES256_CBC of string

  (* PBE encryption algorithms; parameters are (salt, iteration count) *)
  | SHA_RC4_128 of string * int
  | SHA_RC4_40 of string * int
  | SHA_3DES_CBC of string * int
  | SHA_2DES_CBC of string * int
  | SHA_RC2_128_CBC of string * int
  | SHA_RC2_40_CBC of string * int

  (* PBKDF2 (salt, iteration count, key length, pseudo-random function).
     NOTE(review): the key length is not part of the grammar in [identifier];
     it is always [None] after decoding and is dropped when encoding. *)
  | PBKDF2 of string * int * int option * t
  (* PBES2 of two nested algorithms - presumably (key derivation function,
     encryption scheme) as in PKCS#5 PBES2-params; verify against callers. *)
  | PBES2 of t * t

(* Short human-readable name of an algorithm.  Parameters (IVs, salts,
   iteration counts, nested algorithms) are not included in the output. *)
let to_string alg =
  match alg with
  (* public key algorithms *)
  | RSA -> "RSA"
  | EC_pub curve -> ec_curve_to_string curve
  (* RSA signatures *)
  | MD2_RSA -> "RSA MD2"
  | MD4_RSA -> "RSA MD4"
  | MD5_RSA -> "RSA MD5"
  | RIPEMD160_RSA -> "RSA RIPEMD160"
  | SHA1_RSA -> "RSA SHA1"
  | SHA256_RSA -> "RSA SHA256"
  | SHA384_RSA -> "RSA SHA384"
  | SHA512_RSA -> "RSA SHA512"
  | SHA224_RSA -> "RSA SHA224"
  (* ECDSA and EdDSA signatures *)
  | ECDSA_SHA1 -> "ECDSA SHA1"
  | ECDSA_SHA224 -> "ECDSA SHA224"
  | ECDSA_SHA256 -> "ECDSA SHA256"
  | ECDSA_SHA384 -> "ECDSA SHA384"
  | ECDSA_SHA512 -> "ECDSA SHA512"
  | ED25519 -> "Ed25519"
  (* plain digests *)
  | MD2 -> "MD2"
  | MD4 -> "MD4"
  | MD5 -> "MD5"
  | SHA1 -> "SHA1"
  | SHA256 -> "SHA256"
  | SHA384 -> "SHA384"
  | SHA512 -> "SHA512"
  | SHA224 -> "SHA224"
  | SHA512_224 -> "SHA512/224"
  | SHA512_256 -> "SHA512/256"
  (* MACs *)
  | HMAC_SHA1 -> "HMAC SHA1"
  | HMAC_SHA224 -> "HMAC SHA224"
  | HMAC_SHA256 -> "HMAC SHA256"
  | HMAC_SHA384 -> "HMAC SHA384"
  | HMAC_SHA512 -> "HMAC SHA512"
  (* ciphers: the IV is not printed *)
  | AES128_CBC _ -> "AES128 CBC"
  | AES192_CBC _ -> "AES192 CBC"
  | AES256_CBC _ -> "AES256 CBC"
  (* password-based encryption: salt and count are not printed *)
  | SHA_RC4_128 _ -> "PBES: SHA RC4 128"
  | SHA_RC4_40 _ -> "PBES: SHA RC4 40"
  | SHA_3DES_CBC _ -> "PBES: SHA 3DES CBC"
  | SHA_2DES_CBC _ -> "PBES: SHA 2DES CBC"
  | SHA_RC2_128_CBC _ -> "PBES: SHA RC2 128"
  | SHA_RC2_40_CBC _ -> "PBES: SHA RC2 40"
  | PBKDF2 _ -> "PBKDF2"
  | PBES2 _ -> "PBES2"

(* Digest algorithm -> hash tag; [None] for every other algorithm, including
   the digests without a hash tag (MD2, MD4, SHA512/224, SHA512/256). *)
let to_hash alg =
  match alg with
  | MD5 -> Some `MD5
  | SHA1 -> Some `SHA1
  | SHA224 -> Some `SHA224
  | SHA256 -> Some `SHA256
  | SHA384 -> Some `SHA384
  | SHA512 -> Some `SHA512
  | _ -> None

(* Hash tag -> digest algorithm; inverse of [to_hash] on its domain. *)
let of_hash hash =
  match hash with
  | `MD5 -> MD5
  | `SHA1 -> SHA1
  | `SHA224 -> SHA224
  | `SHA256 -> SHA256
  | `SHA384 -> SHA384
  | `SHA512 -> SHA512

(* HMAC algorithm -> underlying hash tag; [None] for non-HMAC algorithms. *)
let to_hmac alg =
  match alg with
  | HMAC_SHA1 -> Some `SHA1
  | HMAC_SHA224 -> Some `SHA224
  | HMAC_SHA256 -> Some `SHA256
  | HMAC_SHA384 -> Some `SHA384
  | HMAC_SHA512 -> Some `SHA512
  | _ -> None

(* Hash tag -> HMAC algorithm; inverse of [to_hmac] on its domain. *)
let of_hmac hash =
  match hash with
  | `SHA1 -> HMAC_SHA1
  | `SHA224 -> HMAC_SHA224
  | `SHA256 -> HMAC_SHA256
  | `SHA384 -> HMAC_SHA384
  | `SHA512 -> HMAC_SHA512

(* Public key algorithm -> key type; [None] for everything else. *)
let to_key_type alg =
  match alg with
  | RSA -> Some `RSA
  | EC_pub curve -> Some (`EC curve)
  | ED25519 -> Some `ED25519
  | _ -> None

(* Key type -> public key algorithm; inverse of [to_key_type]. *)
let of_key_type key_type =
  match key_type with
  | `RSA -> RSA
  | `EC curve -> EC_pub curve
  | `ED25519 -> ED25519

(* Signature algorithm -> (signature scheme, hash).  Ed25519 is reported with
   `SHA512.  XXX: No MD2 / MD4 / RIPEMD160 - those yield [None]. *)
let to_signature_algorithm alg =
  match alg with
  | MD5_RSA -> Some (`RSA_PKCS1, `MD5)
  | SHA1_RSA -> Some (`RSA_PKCS1, `SHA1)
  | SHA224_RSA -> Some (`RSA_PKCS1, `SHA224)
  | SHA256_RSA -> Some (`RSA_PKCS1, `SHA256)
  | SHA384_RSA -> Some (`RSA_PKCS1, `SHA384)
  | SHA512_RSA -> Some (`RSA_PKCS1, `SHA512)
  | ECDSA_SHA1 -> Some (`ECDSA, `SHA1)
  | ECDSA_SHA224 -> Some (`ECDSA, `SHA224)
  | ECDSA_SHA256 -> Some (`ECDSA, `SHA256)
  | ECDSA_SHA384 -> Some (`ECDSA, `SHA384)
  | ECDSA_SHA512 -> Some (`ECDSA, `SHA512)
  | ED25519 -> Some (`ED25519, `SHA512)
  | _ -> None

(* (signature scheme, hash) -> signature algorithm.  For `ED25519 the hash is
   ignored.  Raises [Failure] for any other combination. *)
let of_signature_algorithm public_key_algorithm digest =
  let unsupported () = failwith "unsupported signature scheme and hash" in
  match public_key_algorithm with
  | `ED25519 -> ED25519
  | `RSA_PKCS1 ->
    (match digest with
     | `MD5 -> MD5_RSA
     | `SHA1 -> SHA1_RSA
     | `SHA224 -> SHA224_RSA
     | `SHA256 -> SHA256_RSA
     | `SHA384 -> SHA384_RSA
     | `SHA512 -> SHA512_RSA
     | _ -> unsupported ())
  | `ECDSA ->
    (match digest with
     | `SHA1 -> ECDSA_SHA1
     | `SHA224 -> ECDSA_SHA224
     | `SHA256 -> ECDSA_SHA256
     | `SHA384 -> ECDSA_SHA384
     | `SHA512 -> ECDSA_SHA512
     | _ -> unsupported ())
  | _ -> unsupported ()

(* XXX
 *
 * PKCS1/RFC5280 allows params to be `ANY', depending on the algorithm.  The
 * grammar below accepts only the shapes needed by the algorithms handled
 * here: NULL, an OID, a PBE / PBKDF2 / PBES2 parameter SEQUENCE, or an
 * OCTET STRING.

   RFC 3279 Section 2.2.1 specifies that RSA signature algorithms SHALL have
   NULL as parameter, but certificates in the wild omit the parameters field
   entirely (it is optional). We accept both, and output a NULL parameter.
   Section 2.2.2 specifies that DSA SHALL omit the parameters field,
   Section 2.2.3 specifies that ECDSA MUST omit the parameters field,
   Section 2.3.1 specifies rsaEncryption (for RSA public keys) requires NULL.
*)

(* Conversions between the ANSI X9.62 curve OIDs and [ec_curve].
   [curve_of_oid] falls back to [default] for OIDs outside its table, which
   raises a parse error; [curve_to_oid] is total. *)
let curve_of_oid, curve_to_oid =
  let open Registry.ANSI_X9_62 in
  let of_oid =
    let default oid = Asn.(S.parse_error "Unknown algorithm %a" OID.pp oid) in
    case_of_oid ~default
      [ (secp256r1, `SECP256R1) ;
        (secp384r1, `SECP384R1) ;
        (secp521r1, `SECP521R1) ]
  and to_oid curve =
    match curve with
    | `SECP256R1 -> secp256r1
    | `SECP384R1 -> secp384r1
    | `SECP521R1 -> secp521r1
  in
  (of_oid, to_oid)

(* ASN.1 grammar for AlgorithmIdentifier (algorithm OID plus optional
   parameters), mapped onto [t].

   [f] decodes: it dispatches on the OID and checks that the parameters have
   the shape that algorithm expects, raising a parse error otherwise; the
   [default] handler raises for OIDs that are not in the table.  [g] encodes a
   [t] back into an (OID, parameters) pair.  The grammar is recursive ([fix])
   because PBKDF2 and PBES2 parameters themselves contain AlgorithmIdentifiers. *)
let identifier =
  let open Registry in

  let f =
    (* Parameter-shape checkers.  Each takes the result (or a function
       building it from the parameter) and the decoded optional parameter,
       returns the result when the parameter has the expected shape, and
       raises a parse error otherwise.  `C1 .. `C4 are the alternatives of
       the [choice4] parameter grammar at the end of this function. *)
    let none x = function
      | None -> x
      | _    -> parse_error "Algorithm: expected no parameters"
    and null x = function
      | Some (`C1 ()) -> x
      | _             -> parse_error "Algorithm: expected null parameters"
    and null_or_none x = function
      | None | Some (`C1 ()) -> x
      | _                    -> parse_error "Algorithm: expected null or none parameter"
    and oid f = function
      | Some (`C2 id) -> f id
      | _             -> parse_error "Algorithm: expected parameter OID"
    and pbe f = function
      | Some (`C3 `PBE pbe) -> f pbe
      | _                   -> parse_error "Algorithm: expected parameter PBE"
    and pbkdf2 f = function
      | Some (`C3 `PBKDF2 params) -> f params
      (* RFC 8018 declares the PBKDF2 prf as DEFAULT algid-hmacWithSHA1, and
         DER omits a component equal to its default (X.690 11.5), so encoders
         leave the prf out for HMAC-SHA1.  Without a prf the shared parameter
         grammar below yields `PBE (salt, count); for the PBKDF2 OID that
         means PBKDF2 with the default prf, not a malformed parameter. *)
      | Some (`C3 (`PBE (salt, count))) -> f (salt, count, None, HMAC_SHA1)
      | _                         -> parse_error "Algorithm: expected parameter PBKDF2"
    and pbes2 f = function
      | Some (`C3 `PBES2 params) -> f params
      | _                        -> parse_error "Algorithm: expected parameter PBES2"
    and octets f = function
      | Some (`C4 salt) -> f salt
      | _               -> parse_error "Algorithm: expected parameter octet_string"
    and default oid = Asn.(S.parse_error "Unknown algorithm %a" OID.pp oid)
    in

    case_of_oid_f ~default [

      (ANSI_X9_62.ec_pub_key, oid (fun id -> EC_pub (curve_of_oid id))) ;

      (* rsaEncryption requires NULL; the RSA signature algorithms also
         accept an absent parameter, as seen in the wild *)
      (PKCS1.rsa_encryption          , null RSA                  ) ;
      (PKCS1.md2_rsa_encryption      , null_or_none MD2_RSA      ) ;
      (PKCS1.md4_rsa_encryption      , null_or_none MD4_RSA      ) ;
      (PKCS1.md5_rsa_encryption      , null_or_none MD5_RSA      ) ;
      (PKCS1.ripemd160_rsa_encryption, null_or_none RIPEMD160_RSA) ;
      (PKCS1.sha1_rsa_encryption     , null_or_none SHA1_RSA     ) ;
      (PKCS1.sha256_rsa_encryption   , null_or_none SHA256_RSA   ) ;
      (PKCS1.sha384_rsa_encryption   , null_or_none SHA384_RSA   ) ;
      (PKCS1.sha512_rsa_encryption   , null_or_none SHA512_RSA   ) ;
      (PKCS1.sha224_rsa_encryption   , null_or_none SHA224_RSA   ) ;

      (ANSI_X9_62.ecdsa_sha1         , none ECDSA_SHA1   ) ;
      (ANSI_X9_62.ecdsa_sha224       , none ECDSA_SHA224 ) ;
      (ANSI_X9_62.ecdsa_sha256       , none ECDSA_SHA256 ) ;
      (ANSI_X9_62.ecdsa_sha384       , none ECDSA_SHA384 ) ;
      (ANSI_X9_62.ecdsa_sha512       , none ECDSA_SHA512 ) ;

      (RFC8410.ed25519               , none ED25519 ) ;

      (md2                           , null MD2          ) ;
      (md4                           , null MD4          ) ;
      (md5                           , null MD5          ) ;
      (sha1                          , null SHA1         ) ;
      (sha256                        , null SHA256       ) ;
      (sha384                        , null SHA384       ) ;
      (sha512                        , null SHA512       ) ;
      (sha224                        , null SHA224       ) ;
      (sha512_224                    , null SHA512_224   ) ;
      (sha512_256                    , null SHA512_256   ) ;

      (PKCS2.hmac_sha1               , null HMAC_SHA1    );
      (PKCS2.hmac_sha224             , null HMAC_SHA224  );
      (PKCS2.hmac_sha256             , null HMAC_SHA256  );
      (PKCS2.hmac_sha384             , null HMAC_SHA384  );
      (PKCS2.hmac_sha512             , null HMAC_SHA512  );

      (PKCS5.aes128_cbc              , octets (fun iv -> AES128_CBC iv));
      (PKCS5.aes192_cbc              , octets (fun iv -> AES192_CBC iv));
      (PKCS5.aes256_cbc              , octets (fun iv -> AES256_CBC iv));

      (PKCS12.pbe_with_SHA_and_128Bit_RC4, pbe (fun (s, i) -> SHA_RC4_128 (s, i))) ;
      (PKCS12.pbe_with_SHA_and_40Bit_RC4, pbe (fun (s, i) -> SHA_RC4_40 (s, i))) ;
      (PKCS12.pbe_with_SHA_and_3_KeyTripleDES_CBC, pbe (fun (s, i) -> SHA_3DES_CBC (s, i))) ;
      (PKCS12.pbe_with_SHA_and_2_KeyTripleDES_CBC, pbe (fun (s, i) -> SHA_2DES_CBC (s, i))) ;
      (PKCS12.pbe_with_SHA_and_128Bit_RC2_CBC, pbe (fun (s, i) -> SHA_RC2_128_CBC (s, i))) ;
      (PKCS12.pbe_with_SHA_and_40Bit_RC2_CBC, pbe (fun (s, i) -> SHA_RC2_40_CBC (s, i))) ;

      (PKCS5.pbkdf2, pbkdf2 (fun (s, i, l, m) -> PBKDF2 (s, i, l, m))) ;
      (PKCS5.pbes2, pbes2 (fun (oid, oid') -> PBES2 (oid, oid')))
    ]

  (* Encoder side: the OID and the parameter alternative written for each
     algorithm.  RSA signature algorithms are always written with NULL. *)
  and g =
    let none    = None
    and null    = Some (`C1 ())
    and oid  id = Some (`C2 id)
    and pbe (s, i) = Some (`C3 (`PBE (s, i)))
    and pbkdf2 (s, i, k, m) = Some (`C3 (`PBKDF2 (s, i, k, m)))
    and pbes2 (oid, oid') = Some (`C3 (`PBES2 (oid, oid')))
    and octets data = Some (`C4 data)
    in
    function
    | EC_pub id     -> (ANSI_X9_62.ec_pub_key , oid (curve_to_oid id))

    | RSA           -> (PKCS1.rsa_encryption           , null)
    | MD2_RSA       -> (PKCS1.md2_rsa_encryption       , null)
    | MD4_RSA       -> (PKCS1.md4_rsa_encryption       , null)
    | MD5_RSA       -> (PKCS1.md5_rsa_encryption       , null)
    | RIPEMD160_RSA -> (PKCS1.ripemd160_rsa_encryption , null)
    | SHA1_RSA      -> (PKCS1.sha1_rsa_encryption      , null)
    | SHA256_RSA    -> (PKCS1.sha256_rsa_encryption    , null)
    | SHA384_RSA    -> (PKCS1.sha384_rsa_encryption    , null)
    | SHA512_RSA    -> (PKCS1.sha512_rsa_encryption    , null)
    | SHA224_RSA    -> (PKCS1.sha224_rsa_encryption    , null)

    | ECDSA_SHA1    -> (ANSI_X9_62.ecdsa_sha1          , none)
    | ECDSA_SHA224  -> (ANSI_X9_62.ecdsa_sha224        , none)
    | ECDSA_SHA256  -> (ANSI_X9_62.ecdsa_sha256        , none)
    | ECDSA_SHA384  -> (ANSI_X9_62.ecdsa_sha384        , none)
    | ECDSA_SHA512  -> (ANSI_X9_62.ecdsa_sha512        , none)

    | ED25519       -> (RFC8410.ed25519                , none)

    | MD2           -> (md2                            , null)
    | MD4           -> (md4                            , null)
    | MD5           -> (md5                            , null)
    | SHA1          -> (sha1                           , null)
    | SHA256        -> (sha256                         , null)
    | SHA384        -> (sha384                         , null)
    | SHA512        -> (sha512                         , null)
    | SHA224        -> (sha224                         , null)
    | SHA512_224    -> (sha512_224                     , null)
    | SHA512_256    -> (sha512_256                     , null)

    | HMAC_SHA1     -> (PKCS2.hmac_sha1                , null)
    | HMAC_SHA224   -> (PKCS2.hmac_sha224              , null)
    | HMAC_SHA256   -> (PKCS2.hmac_sha256              , null)
    | HMAC_SHA384   -> (PKCS2.hmac_sha384              , null)
    | HMAC_SHA512   -> (PKCS2.hmac_sha512              , null)

    | AES128_CBC iv -> (PKCS5.aes128_cbc               , octets iv)
    | AES192_CBC iv -> (PKCS5.aes192_cbc               , octets iv)
    | AES256_CBC iv -> (PKCS5.aes256_cbc               , octets iv)

    | SHA_RC4_128 (s, i) -> (PKCS12.pbe_with_SHA_and_128Bit_RC4, pbe (s, i))
    | SHA_RC4_40 (s, i) -> (PKCS12.pbe_with_SHA_and_40Bit_RC4, pbe (s, i))
    | SHA_3DES_CBC (s, i) -> (PKCS12.pbe_with_SHA_and_3_KeyTripleDES_CBC, pbe (s, i))
    | SHA_2DES_CBC (s, i) -> (PKCS12.pbe_with_SHA_and_2_KeyTripleDES_CBC, pbe (s, i))
    | SHA_RC2_128_CBC (s, i) -> (PKCS12.pbe_with_SHA_and_128Bit_RC2_CBC, pbe (s, i))
    | SHA_RC2_40_CBC (s, i) -> (PKCS12.pbe_with_SHA_and_40Bit_RC2_CBC, pbe (s, i))

    (* NOTE(review): the prf is always written, even when it is HMAC_SHA1;
       strict DER would omit a value equal to its DEFAULT. *)
    | PBKDF2 (s, i, k, m) -> (PKCS5.pbkdf2, pbkdf2 (s, i, k, m))
    | PBES2 (oid, oid') -> (PKCS5.pbes2, pbes2 (oid, oid'))
  in

  fix (fun id ->
      (* One SEQUENCE grammar shared by three parameter shapes:
         - PBE:    salt OCTET STRING, iteration count
         - PBKDF2: salt OCTET STRING, iteration count, prf AlgorithmIdentifier
         - PBES2:  AlgorithmIdentifier, AlgorithmIdentifier
         The shape is told apart by which fields are present. *)
      let pbkdf2_or_pbe_or_pbes2_params =
        (* TODO PBKDF2 should support `C2 oid (saltSources) *)
        let f (salt, count, (* key_len, *) prf) =
          match salt, count, (* key_len, *) prf with
          | `C1 salt, Some count, (* None, *) None -> `PBE (salt, count)
          | `C1 salt, Some count, (* x, *) Some prf -> `PBKDF2 (salt, count, None, prf)
          | `C2 oid, None, (* None, *) Some oid' -> `PBES2 (oid, oid')
          | _ -> parse_error "bad parameters"
        and g = function
          | `PBE (salt, count) -> (`C1 salt, Some count, (* None, *) None)
          | `PBKDF2 (salt, count, _key_len, prf) -> (`C1 salt, Some count, (* key_len, *) Some prf)
          | `PBES2 (oid, oid') -> (`C2 oid, None, (* None, *) Some oid')
        in
        map f g @@
        sequence3
          (required ~label:"salt" (choice2 octet_string id))
          (optional ~label:"iteration count" int) (* modified - required for pbkdf2/pbes *)
          (* (optional ~label:"key length" int) (* should be there and optional *) *)
          (optional ~label:"prf" id) (* only present in pbkdf2 / pbes2 *)
      in
      map f g @@
      sequence2
        (required ~label:"algorithm" oid)
        (optional ~label:"params"
           (choice4 null oid pbkdf2_or_pbe_or_pbes2_params octet_string)))

(* ECDSA signature value: a SEQUENCE of the two unsigned integers r and s. *)
let ecdsa_sig =
  let r = required ~label:"r" unsigned_integer
  and s = required ~label:"s" unsigned_integer in
  sequence2 r s

(* DER decoder / encoder pair for [ecdsa_sig]. *)
let ecdsa_sig_of_octets, ecdsa_sig_to_octets =
  let decode, encode = projections_of Asn.der ecdsa_sig in
  (decode, encode)

(* Pretty-printer: writes the [to_string] name of the algorithm. *)
let pp ppf alg = to_string alg |> Fmt.string ppf
