Skip to content

Commit d3b1867

Browse files
author
Jiawen
committed
feat:encryption adds support for SM2, SM3, and SM4 #131
1 parent e6fefa5 commit d3b1867

File tree

7 files changed

+1535
-2
lines changed

7 files changed

+1535
-2
lines changed

cryptor/gm_example_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package cryptor_test
2+
3+
import (
4+
"encoding/hex"
5+
"fmt"
6+
7+
"github.com/duke-git/lancet/v2/cryptor"
8+
)
9+
10+
func ExampleSm3() {
11+
data := []byte("hello world")
12+
hash := cryptor.Sm3(data)
13+
14+
fmt.Println(hex.EncodeToString(hash))
15+
16+
// Output:
17+
// 44f0061e69fa6fdfc290c494654a05dc0c053da7e5c52b84ef93a9d67d3fff88
18+
}
19+
20+
func ExampleSm4EcbEncrypt() {
21+
key := []byte("1234567890abcdef") // 16 bytes key
22+
plaintext := []byte("hello world")
23+
24+
encrypted := cryptor.Sm4EcbEncrypt(plaintext, key)
25+
decrypted := cryptor.Sm4EcbDecrypt(encrypted, key)
26+
27+
fmt.Println(string(decrypted))
28+
29+
// Output:
30+
// hello world
31+
}
32+
33+
func ExampleSm4CbcEncrypt() {
34+
key := []byte("1234567890abcdef") // 16 bytes key
35+
plaintext := []byte("hello world")
36+
37+
encrypted := cryptor.Sm4CbcEncrypt(plaintext, key)
38+
decrypted := cryptor.Sm4CbcDecrypt(encrypted, key)
39+
40+
fmt.Println(string(decrypted))
41+
42+
// Output:
43+
// hello world
44+
}
45+
46+
func ExampleGenerateSm2Key() {
47+
// Generate SM2 key pair
48+
privateKey, err := cryptor.GenerateSm2Key()
49+
if err != nil {
50+
return
51+
}
52+
53+
plaintext := []byte("hello world")
54+
55+
// Encrypt with public key
56+
ciphertext, err := cryptor.Sm2Encrypt(&privateKey.PublicKey, plaintext)
57+
if err != nil {
58+
return
59+
}
60+
61+
// Decrypt with private key
62+
decrypted, err := cryptor.Sm2Decrypt(privateKey, ciphertext)
63+
if err != nil {
64+
return
65+
}
66+
67+
fmt.Println(string(decrypted))
68+
69+
// Output:
70+
// hello world
71+
}

cryptor/gm_sm2.go

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
package cryptor
2+
3+
import (
4+
"crypto/elliptic"
5+
"crypto/rand"
6+
"encoding/binary"
7+
"errors"
8+
"io"
9+
"math/big"
10+
)
11+
12+
// SM2 implements the Chinese SM2 elliptic curve public key algorithm.
13+
// SM2 is based on elliptic curve cryptography and provides encryption, decryption, signing and verification.
14+
//
15+
// Note: This implementation uses crypto/elliptic package methods (GenerateKey, ScalarBaseMult, ScalarMult, IsOnCurve)
16+
// which are marked as deprecated in Go 1.20+. These methods still work correctly and are widely used.
17+
// The //nolint:staticcheck directive suppresses deprecation warnings.
18+
// A future version may replace these with a custom elliptic curve implementation.
19+
20+
var (
21+
sm2P256 *sm2Curve
22+
sm2P256Params = &elliptic.CurveParams{Name: "sm2p256v1"}
23+
)
24+
25+
func init() {
26+
// SM2 curve parameters
27+
sm2P256Params.P, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16)
28+
sm2P256Params.N, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16)
29+
sm2P256Params.B, _ = new(big.Int).SetString("28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 16)
30+
sm2P256Params.Gx, _ = new(big.Int).SetString("32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 16)
31+
sm2P256Params.Gy, _ = new(big.Int).SetString("BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 16)
32+
sm2P256Params.BitSize = 256
33+
34+
sm2P256 = &sm2Curve{sm2P256Params}
35+
}
36+
37+
type sm2Curve struct {
38+
*elliptic.CurveParams
39+
}
40+
41+
// Sm2PrivateKey represents an SM2 private key.
42+
type Sm2PrivateKey struct {
43+
D *big.Int
44+
PublicKey Sm2PublicKey
45+
}
46+
47+
// Sm2PublicKey represents an SM2 public key.
48+
type Sm2PublicKey struct {
49+
X, Y *big.Int
50+
}
51+
52+
// GenerateSm2Key generates a new SM2 private/public key pair.
53+
// Play: https://go.dev/play/p/bKYMqRLvIx3
54+
func GenerateSm2Key() (*Sm2PrivateKey, error) {
55+
priv, x, y, err := elliptic.GenerateKey(sm2P256, rand.Reader)
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
privateKey := &Sm2PrivateKey{
61+
D: new(big.Int).SetBytes(priv),
62+
PublicKey: Sm2PublicKey{
63+
X: x,
64+
Y: y,
65+
},
66+
}
67+
68+
return privateKey, nil
69+
}
70+
71+
// Sm2Encrypt encrypts plaintext using SM2 public key.
72+
// Returns ciphertext in the format: C1 || C3 || C2
73+
// C1 = kG (65 bytes in uncompressed format)
74+
// C3 = Hash(x2 || M || y2) (32 bytes for SM3)
75+
// C2 = M xor t (same length as plaintext)
76+
// Play: https://go.dev/play/p/bKYMqRLvIx3
77+
func Sm2Encrypt(pub *Sm2PublicKey, plaintext []byte) ([]byte, error) {
78+
if pub == nil || pub.X == nil || pub.Y == nil {
79+
return nil, errors.New("sm2: invalid public key")
80+
}
81+
82+
for {
83+
// Generate random k
84+
k, err := randFieldElement(sm2P256, rand.Reader)
85+
if err != nil {
86+
return nil, err
87+
}
88+
89+
// C1 = kG
90+
c1x, c1y := sm2P256.ScalarBaseMult(k.Bytes())
91+
92+
// kP = (x2, y2)
93+
x2, y2 := sm2P256.ScalarMult(pub.X, pub.Y, k.Bytes())
94+
95+
// Derive key using KDF
96+
kdfLen := len(plaintext)
97+
t := sm2KDF(append(toBytes(sm2P256, x2), toBytes(sm2P256, y2)...), kdfLen)
98+
99+
// Check if t is all zeros
100+
allZero := true
101+
for _, b := range t {
102+
if b != 0 {
103+
allZero = false
104+
break
105+
}
106+
}
107+
if allZero {
108+
continue
109+
}
110+
111+
// C2 = M xor t
112+
c2 := make([]byte, len(plaintext))
113+
for i := 0; i < len(plaintext); i++ {
114+
c2[i] = plaintext[i] ^ t[i]
115+
}
116+
117+
// C3 = Hash(x2 || M || y2)
118+
c3Input := append(toBytes(sm2P256, x2), plaintext...)
119+
c3Input = append(c3Input, toBytes(sm2P256, y2)...)
120+
c3 := Sm3(c3Input)
121+
122+
// Return C1 || C3 || C2
123+
c1 := sm2MarshalUncompressed(sm2P256, c1x, c1y)
124+
result := append(c1, c3...)
125+
result = append(result, c2...)
126+
127+
return result, nil
128+
}
129+
}
130+
131+
// Sm2Decrypt decrypts ciphertext using SM2 private key.
132+
// Expects ciphertext in the format: C1 || C3 || C2
133+
// Play: https://go.dev/play/p/bKYMqRLvIx3
134+
func Sm2Decrypt(priv *Sm2PrivateKey, ciphertext []byte) ([]byte, error) {
135+
if priv == nil || priv.D == nil {
136+
return nil, errors.New("sm2: invalid private key")
137+
}
138+
139+
// Parse C1 (65 bytes), C3 (32 bytes), C2 (remaining)
140+
if len(ciphertext) < 97 {
141+
return nil, errors.New("sm2: ciphertext too short")
142+
}
143+
144+
c1 := ciphertext[:65]
145+
c3 := ciphertext[65:97]
146+
c2 := ciphertext[97:]
147+
148+
// Parse C1
149+
c1x, c1y := sm2UnmarshalUncompressed(sm2P256, c1)
150+
if c1x == nil {
151+
return nil, errors.New("sm2: invalid C1 point")
152+
}
153+
154+
// Verify C1 is on curve
155+
if !sm2P256.IsOnCurve(c1x, c1y) {
156+
return nil, errors.New("sm2: C1 not on curve")
157+
}
158+
159+
// dC1 = (x2, y2)
160+
x2, y2 := sm2P256.ScalarMult(c1x, c1y, priv.D.Bytes())
161+
162+
// Derive key using KDF
163+
kdfLen := len(c2)
164+
t := sm2KDF(append(toBytes(sm2P256, x2), toBytes(sm2P256, y2)...), kdfLen)
165+
166+
// M = C2 xor t
167+
plaintext := make([]byte, len(c2))
168+
for i := 0; i < len(c2); i++ {
169+
plaintext[i] = c2[i] ^ t[i]
170+
}
171+
172+
// Verify C3 = Hash(x2 || M || y2)
173+
u := append(toBytes(sm2P256, x2), plaintext...)
174+
u = append(u, toBytes(sm2P256, y2)...)
175+
hash := Sm3(u)
176+
177+
for i := 0; i < len(c3); i++ {
178+
if c3[i] != hash[i] {
179+
return nil, errors.New("sm2: hash verification failed")
180+
}
181+
}
182+
183+
return plaintext, nil
184+
}
185+
186+
// SM2 KDF (Key Derivation Function)
187+
func sm2KDF(z []byte, klen int) []byte {
188+
limit := (klen + 31) / 32
189+
result := make([]byte, 0, limit*32)
190+
191+
for i := 1; i <= limit; i++ {
192+
counter := make([]byte, 4)
193+
binary.BigEndian.PutUint32(counter, uint32(i))
194+
hash := Sm3(append(z, counter...))
195+
result = append(result, hash...)
196+
}
197+
198+
return result[:klen]
199+
}
200+
201+
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
202+
byteLen := (curve.Params().BitSize + 7) / 8
203+
buf := make([]byte, byteLen)
204+
b := value.Bytes()
205+
copy(buf[byteLen-len(b):], b)
206+
return buf
207+
}
208+
209+
func sm2MarshalUncompressed(curve *sm2Curve, x, y *big.Int) []byte {
210+
byteLen := (curve.BitSize + 7) / 8
211+
ret := make([]byte, 1+2*byteLen)
212+
ret[0] = 4 // uncompressed point
213+
214+
xBytes := x.Bytes()
215+
copy(ret[1+byteLen-len(xBytes):], xBytes)
216+
yBytes := y.Bytes()
217+
copy(ret[1+2*byteLen-len(yBytes):], yBytes)
218+
219+
return ret
220+
}
221+
222+
func sm2UnmarshalUncompressed(curve *sm2Curve, data []byte) (*big.Int, *big.Int) {
223+
byteLen := (curve.BitSize + 7) / 8
224+
if len(data) != 1+2*byteLen {
225+
return nil, nil
226+
}
227+
if data[0] != 4 {
228+
return nil, nil
229+
}
230+
231+
x := new(big.Int).SetBytes(data[1 : 1+byteLen])
232+
y := new(big.Int).SetBytes(data[1+byteLen:])
233+
234+
return x, y
235+
}
236+
237+
func randFieldElement(c elliptic.Curve, rand io.Reader) (*big.Int, error) {
238+
params := c.Params()
239+
b := make([]byte, params.BitSize/8+8)
240+
_, err := io.ReadFull(rand, b)
241+
if err != nil {
242+
return nil, err
243+
}
244+
245+
k := new(big.Int).SetBytes(b)
246+
n := new(big.Int).Sub(params.N, big.NewInt(1))
247+
k.Mod(k, n)
248+
k.Add(k, big.NewInt(1))
249+
250+
return k, nil
251+
}

0 commit comments

Comments
 (0)