diff --git a/gm/sm2/dde-sm2.c b/gm/sm2/dde-sm2.c index 52272b8..ca27e98 100644 --- a/gm/sm2/dde-sm2.c +++ b/gm/sm2/dde-sm2.c @@ -8,7 +8,6 @@ #include #include #include -#include #include "dde-sm2.h" @@ -86,6 +85,166 @@ static char* get_private_key(EC_KEY *key) { return ret; } +/*openssl sm2 cipher evp using*/ +static int openssl_evp_sm2_encrypt(EC_KEY *ec_key, + const unsigned char *plain_text, size_t plain_len, + unsigned char *cipher_text, size_t *cipher_len) +{ + int ret = 0; + BIO *bp = NULL; + EVP_PKEY* public_evp_key = NULL; + EVP_PKEY_CTX *ctx = NULL; + + /*Check the user input.*/ + if (plain_text == NULL || plain_len == 0) { + ret = -1; + return ret; + } + + //OpenSSL_add_all_algorithms(); + bp = BIO_new(BIO_s_mem()); + if (bp == NULL) { + printf("BIO_new is failed.\n"); + ret = -1; + return ret; + } + + if (ec_key == NULL) { + ret = -1; + printf("open_public_key failed to PEM_read_bio_EC_PUBKEY Failed, ret=%d\n", ret); + goto finish; + } + public_evp_key = EVP_PKEY_new(); + if (public_evp_key == NULL) { + ret = -1; + printf("open_public_key EVP_PKEY_new failed\n"); + goto finish; + } + ret = EVP_PKEY_set1_EC_KEY(public_evp_key, ec_key); + if (ret != 1) { + ret = -1; + printf("EVP_PKEY_set1_EC_KEY failed\n"); + goto finish; + } + ret = EVP_PKEY_set_alias_type(public_evp_key, EVP_PKEY_SM2); + if (ret != 1) { + printf("EVP_PKEY_set_alias_type to EVP_PKEY_SM2 failed! ret = %d\n", ret); + ret = -1; + goto finish; + } + /*modifying a EVP_PKEY to use a different set of algorithms than the default.*/ + + /*do cipher.*/ + ctx = EVP_PKEY_CTX_new(public_evp_key, NULL); + if (ctx == NULL) { + ret = -1; + printf("EVP_PKEY_CTX_new failed\n"); + goto finish; + } + ret = EVP_PKEY_encrypt_init(ctx); + if (ret < 0) { + printf("sm2_pubkey_encrypt failed to EVP_PKEY_encrypt_init. ret = %d\n", ret); + goto finish; + } + ret = EVP_PKEY_encrypt(ctx, cipher_text, cipher_len, plain_text, plain_len); + if (ret < 0) { + printf("sm2_pubkey_encrypt failed to EVP_PKEY_encrypt. ret = %d\n", ret); + goto finish; + } + ret = 0; + +finish: + if (public_evp_key != NULL) + EVP_PKEY_free(public_evp_key); + if (ctx != NULL) + EVP_PKEY_CTX_free(ctx); + if (bp != NULL) + BIO_free(bp); + + return ret; +} + +/*openssl sm2 decrypt evp using*/ +static int openssl_evp_sm2_decrypt(EC_KEY* ec_key, + const unsigned char *cipher_text, size_t cipher_len, + unsigned char *plain_text, size_t *plain_len) +{ + int ret = 0; + size_t out_len = 0; + BIO *bp = NULL; + EVP_PKEY* private_evp_key = NULL; + EVP_PKEY_CTX *ctx = NULL; + + /*Check the user input.*/ + if (cipher_len == 0 || cipher_text == NULL) { + ret = -1; + return ret; + } + + //OpenSSL_add_all_algorithms(); + bp = BIO_new(BIO_s_mem()); + if (bp == NULL) { + printf("BIO_new is failed.\n"); + ret = -1; + return ret; + } + + if (ec_key == NULL) { + ret = -1; + printf("open_private_key failed to PEM_read_bio_ECPrivateKey Failed, ret=%d\n", ret); + goto finish; + } + private_evp_key = EVP_PKEY_new(); + if (private_evp_key == NULL) { + ret = -1; + printf("open_public_key EVP_PKEY_new failed\n"); + goto finish; + } + ret = EVP_PKEY_set1_EC_KEY(private_evp_key, ec_key); + if (ret != 1) { + ret = -1; + printf("EVP_PKEY_set1_EC_KEY failed\n"); + goto finish; + } + ret = EVP_PKEY_set_alias_type(private_evp_key, EVP_PKEY_SM2); + if (ret != 1) { + printf("EVP_PKEY_set_alias_type to EVP_PKEY_SM2 failed! ret = %d\n", ret); + ret = -1; + goto finish; + } + /*modifying a EVP_PKEY to use a different set of algorithms than the default.*/ + + /*do cipher.*/ + ctx = EVP_PKEY_CTX_new(private_evp_key, NULL); + if (ctx == NULL) { + ret = -1; + printf("EVP_PKEY_CTX_new failed\n"); + goto finish; + } + ret = EVP_PKEY_decrypt_init(ctx); + if (ret < 0) { + printf("sm2 private_key decrypt failed to EVP_PKEY_decrypt_init. ret = %d\n", ret); + goto finish; + } + + ret = EVP_PKEY_decrypt(ctx, plain_text, plain_len, cipher_text, cipher_len); + if (ret < 0) { + printf("sm2_prikey_decrypt failed to EVP_PKEY_decrypt. ret = %d\n", ret); + goto finish; + } + ret = 0; +finish: + if (private_evp_key != NULL) + EVP_PKEY_free(private_evp_key); + if (ctx != NULL) + EVP_PKEY_CTX_free(ctx); + if (bp != NULL) + BIO_free(bp); + + return ret; +} + + sm2_context* new_sm2_context() { EC_KEY* key = gen_ec_key(); if (key == NULL) { @@ -125,18 +284,18 @@ const char* get_sm2_private_key(sm2_context* context) { return context->private_key; } -int get_ciphertext_size(const sm2_context *context, size_t plen) { +int get_ciphertext_size(const sm2_context *context, const uint8_t *ptext, size_t plen) { size_t ret = 0; - if (1 == sm2_ciphertext_size(context->key, EVP_sm3(), plen, &ret)) { + if (0 == openssl_evp_sm2_encrypt(context->key, ptext, plen, NULL, &ret)) { return (int)ret; } return -1; } -int get_plaintext_size(const uint8_t *ctext, size_t clen) { +int get_plaintext_size(const sm2_context *context, const uint8_t *ctext, size_t clen) { size_t ret = 0; - if (1 == sm2_plaintext_size(ctext, clen, &ret)) { + if (0 == openssl_evp_sm2_decrypt(context->key, ctext, clen, NULL, &ret)) { return (int)ret; } @@ -144,7 +303,7 @@ int get_plaintext_size(const uint8_t *ctext, size_t clen) { } int encrypt(const sm2_context* context, const uint8_t *ptext, size_t psize, uint8_t *ctext, size_t csize) { - if (1 == sm2_encrypt(context->key, EVP_sm3(), ptext, psize, ctext, &csize)) { + if (0 == openssl_evp_sm2_encrypt(context->key, ptext, psize, ctext, &csize)) { return (int)csize; } @@ -152,7 +311,7 @@ int encrypt(const sm2_context* context, const uint8_t *ptext, size_t psize, uint } int decrypt(const sm2_context* context, const uint8_t *ctext, size_t clen, uint8_t *ptext, size_t psize) { - if (1 == sm2_decrypt(context->key, EVP_sm3(), ctext, clen, ptext, &psize)) { + if (0 == openssl_evp_sm2_decrypt(context->key, ctext, clen, ptext, &psize)) { return (int)psize; } diff --git a/gm/sm2/dde-sm2.h b/gm/sm2/dde-sm2.h index ea87f07..00ac672 100644 --- a/gm/sm2/dde-sm2.h +++ b/gm/sm2/dde-sm2.h @@ -16,8 +16,8 @@ void free_sm2_context(sm2_context *context); const char* get_sm2_public_key(sm2_context *context); const char* get_sm2_private_key(sm2_context *context); -int get_ciphertext_size(const sm2_context *context, size_t plen); -int get_plaintext_size(const uint8_t *ctext, size_t clen); +int get_ciphertext_size(const sm2_context *context, const uint8_t *ptext, size_t plen); +int get_plaintext_size(const sm2_context *context, const uint8_t *ctext, size_t clen); int encrypt(const sm2_context *context, const uint8_t *ptext, size_t psize, uint8_t *ctext, size_t csize); int decrypt(const sm2_context *context, const uint8_t *ctext, size_t csize, uint8_t *ptext, size_t psize); diff --git a/gm/sm2/sm2.go b/gm/sm2/sm2.go index 93a40d3..e235b8a 100644 --- a/gm/sm2/sm2.go +++ b/gm/sm2/sm2.go @@ -37,7 +37,7 @@ func (s *SM2Helper) Encrypt(p []byte) ([]byte, error) { if len(p) == 0 { return nil, fmt.Errorf("plaintext size is zero") } - size := C.get_ciphertext_size(s.context, C.size_t(len(p))) + size := C.get_ciphertext_size(s.context, (*C.uint8_t)(unsafe.Pointer(&p[0])), C.size_t(len(p))) if size <= 0 { return nil, fmt.Errorf("get ciphertext size failed") } @@ -56,7 +56,7 @@ func (s *SM2Helper) Decrypt(c []byte) ([]byte, error) { if len(c) == 0 { return nil, fmt.Errorf("ciphertext size is zero") } - size := C.get_plaintext_size((*C.uint8_t)(unsafe.Pointer(&c[0])), C.size_t(len(c))) + size := C.get_plaintext_size(s.context, (*C.uint8_t)(unsafe.Pointer(&c[0])), C.size_t(len(c))) if size <= 0 { return nil, fmt.Errorf("get plaintext size failed") } diff --git a/gm/sm4/sm4.go b/gm/sm4/sm4.go index 73b38ae..e77cf04 100644 --- a/gm/sm4/sm4.go +++ b/gm/sm4/sm4.go @@ -4,8 +4,33 @@ package sm4 -// #include -// #cgo pkg-config: openssl +/* +#cgo pkg-config: openssl +#include + +static void openssl_evp_sm4_cipher(const unsigned char *key, + unsigned char *out, + unsigned char *in, int inl, + int enc) { + int ret = 0; + EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); + if (ctx == NULL) { + return; + } + ret = EVP_CipherInit(ctx, EVP_sm4_ecb(), key, NULL, enc); + if (1 != ret) { + printf("EVP_CipherInit fail... ret = %d \n", ret); + EVP_CIPHER_CTX_free(ctx); + return; + } + ret = EVP_Cipher(ctx, out, in, inl); + if (1 != ret) { + printf("EVP_Cipher fail.. ret = %d \n", ret); + } + + EVP_CIPHER_CTX_free(ctx); +} +*/ import "C" import ( "crypto/cipher" @@ -21,7 +46,7 @@ const ( // A cipher is an instance of SM4 encryption using a particular key. type sm4Cipher struct { - key C.SM4_KEY + key []byte } // NewCipher creates and returns a new cipher.Block. @@ -35,16 +60,24 @@ func NewCipher(key []byte) (cipher.Block, error) { break } ret := &sm4Cipher{} - C.SM4_set_key((*C.uint8_t)(unsafe.Pointer(&key[0])), &ret.key) + + ret.key = make([]byte, k) + copy(ret.key, key) return ret, nil } func (c *sm4Cipher) BlockSize() int { return BlockSize } func (c *sm4Cipher) Encrypt(dst, src []byte) { - C.SM4_encrypt((*C.uint8_t)(unsafe.Pointer(&src[0])), (*C.uint8_t)(unsafe.Pointer(&dst[0])), &c.key) + C.openssl_evp_sm4_cipher((*C.uint8_t)(unsafe.Pointer(&c.key[0])), + (*C.uint8_t)(unsafe.Pointer(&dst[0])), + (*C.uint8_t)(unsafe.Pointer(&src[0])), C.int(len(src)), + 1) } func (c *sm4Cipher) Decrypt(dst, src []byte) { - C.SM4_decrypt((*C.uint8_t)(unsafe.Pointer(&src[0])), (*C.uint8_t)(unsafe.Pointer(&dst[0])), &c.key) + C.openssl_evp_sm4_cipher((*C.uint8_t)(unsafe.Pointer(&c.key[0])), + (*C.uint8_t)(unsafe.Pointer(&dst[0])), + (*C.uint8_t)(unsafe.Pointer(&src[0])), C.int(len(src)), + 0) }