diff --git a/signer/plugin.go b/signer/plugin.go index 6d8615d7..6c0d353c 100644 --- a/signer/plugin.go +++ b/signer/plugin.go @@ -162,7 +162,7 @@ func (s *PluginSigner) generateSignature(ctx context.Context, desc ocispec.Descr logger := log.GetLogger(ctx) logger.Debug("Generating signature by plugin") genericSigner := GenericSigner{ - signer: &pluginPrimitiveSigner{ + signer: &PluginPrimitiveSigner{ ctx: ctx, plugin: s.plugin, keyID: s.keyID, @@ -325,8 +325,8 @@ func parseCertChain(certChain [][]byte) ([]*x509.Certificate, error) { return certs, nil } -// pluginPrimitiveSigner implements signature.Signer -type pluginPrimitiveSigner struct { +// PluginPrimitiveSigner implements signature.Signer +type PluginPrimitiveSigner struct { ctx context.Context plugin plugin.SignPlugin keyID string @@ -335,7 +335,7 @@ type pluginPrimitiveSigner struct { } // Sign signs the digest by calling the underlying plugin. -func (s *pluginPrimitiveSigner) Sign(payload []byte) ([]byte, []*x509.Certificate, error) { +func (s *PluginPrimitiveSigner) Sign(payload []byte) ([]byte, []*x509.Certificate, error) { // Execute plugin sign command. keySpec, err := proto.EncodeKeySpec(s.keySpec) if err != nil { @@ -372,6 +372,35 @@ func (s *pluginPrimitiveSigner) Sign(payload []byte) ([]byte, []*x509.Certificat // KeySpec returns the keySpec of a keyID by calling describeKey and do some // keySpec validation. -func (s *pluginPrimitiveSigner) KeySpec() (signature.KeySpec, error) { +func (s *PluginPrimitiveSigner) KeySpec() (signature.KeySpec, error) { return s.keySpec, nil } + +// NewPluginPrimitiveSigner creates a new PluginPrimitiveSigner that delegates +// signing to a plugin. This is used for dm-verity PKCS#7 signing where raw +// signature bytes are needed instead of JWS/COSE envelopes. +func NewPluginPrimitiveSigner(ctx context.Context, p plugin.SignPlugin, keyID string, keySpec signature.KeySpec, pluginConfig map[string]string) *PluginPrimitiveSigner { + return &PluginPrimitiveSigner{ + ctx: ctx, + plugin: p, + keyID: keyID, + keySpec: keySpec, + pluginConfig: pluginConfig, + } +} + +// GetKeySpecFromPlugin retrieves the key specification from a plugin by calling DescribeKey. +func GetKeySpecFromPlugin(ctx context.Context, p plugin.SignPlugin, keyID string, pluginConfig map[string]string) (signature.KeySpec, error) { + req := &plugin.DescribeKeyRequest{ + ContractVersion: plugin.ContractVersion, + KeyID: keyID, + PluginConfig: pluginConfig, + } + + resp, err := p.DescribeKey(ctx, req) + if err != nil { + return signature.KeySpec{}, err + } + + return proto.DecodeKeySpec(resp.KeySpec) +} diff --git a/signer/plugin_test.go b/signer/plugin_test.go index a98caa07..5d839d56 100644 --- a/signer/plugin_test.go +++ b/signer/plugin_test.go @@ -166,7 +166,7 @@ func (p *mockPlugin) GenerateEnvelope(ctx context.Context, req *proto.GenerateEn return nil, err } - primitivePluginSigner := &pluginPrimitiveSigner{ + primitivePluginSigner := &PluginPrimitiveSigner{ ctx: ctx, plugin: internalPluginSigner.plugin, keyID: internalPluginSigner.keyID, @@ -492,3 +492,53 @@ func basicSignTest(t *testing.T, ps *PluginSigner, envelopeType string, data []b } basicVerification(t, data, envelopeType, mockPlugin.certs[len(mockPlugin.certs)-1], &validMetadata) } + +func TestNewPluginPrimitiveSigner(t *testing.T) { + ctx := context.Background() + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + + s := NewPluginPrimitiveSigner(ctx, mp, "testKeyID", defaultKeySpec, nil) + + // verify KeySpec + ks, err := s.KeySpec() + if err != nil { + t.Fatalf("KeySpec() error: %v", err) + } + if ks != defaultKeySpec { + t.Fatalf("KeySpec() = %v, want %v", ks, defaultKeySpec) + } + + // verify Sign + sig, certs, err := s.Sign([]byte("payload")) + if err != nil { + t.Fatalf("Sign() error: %v", err) + } + if len(sig) == 0 { + t.Fatal("Sign() returned empty signature") + } + if len(certs) == 0 { + t.Fatal("Sign() returned no certificates") + } +} + +func TestGetKeySpecFromPlugin(t *testing.T) { + ctx := context.Background() + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + + got, err := GetKeySpecFromPlugin(ctx, mp, "testKeyID", nil) + if err != nil { + t.Fatalf("GetKeySpecFromPlugin() error: %v", err) + } + if got != defaultKeySpec { + t.Fatalf("GetKeySpecFromPlugin() = %v, want %v", got, defaultKeySpec) + } +} + +func TestGetKeySpecFromPlugin_Error(t *testing.T) { + ctx := context.Background() + mp := &mockPlugin{} + _, err := GetKeySpecFromPlugin(ctx, mp, "testKeyID", nil) + if err == nil { + t.Fatal("expected error for empty keySpec, got nil") + } +}