From fd572c44dc0d390076fc8ea2f9f185bef4fb1380 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Fri, 21 Feb 2025 16:55:39 +0000 Subject: [PATCH] Pass fewer snippets to suspicious commands Signed-off-by: nigel brown --- src/codegate/pipeline/comment/output.py | 16 +++-- .../suspicious_commands.py | 36 ++++++---- tests/test_suspicious_commands.py | 67 +++++++++++++++++++ 3 files changed, 100 insertions(+), 19 deletions(-) diff --git a/src/codegate/pipeline/comment/output.py b/src/codegate/pipeline/comment/output.py index 44a2f1af..4583a659 100644 --- a/src/codegate/pipeline/comment/output.py +++ b/src/codegate/pipeline/comment/output.py @@ -12,8 +12,7 @@ ) from codegate.pipeline.base import PipelineContext from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep - -# from codegate.pipeline.suspicious_commands.suspicious_commands import check_suspicious_code +from codegate.pipeline.suspicious_commands.suspicious_commands import check_suspicious_code from codegate.storage import StorageEngine from codegate.utils.package_extractor import PackageExtractor @@ -53,10 +52,15 @@ async def _snippet_comment(self, snippet: CodeSnippet, context: PipelineContext) """Create a comment for a snippet""" comment = "" - # Remove this for now. We need to find a better place for it. - # comment, is_suspicious = await check_suspicious_code(snippet.code, snippet.language) - # if is_suspicious: - # comment += comment + if ( + snippet.filepath is None + and snippet.file_extension is None + and "filepath" not in snippet.code + and "existing code" not in snippet.code + ): + new_comment, is_suspicious = await check_suspicious_code(snippet.code, snippet.language) + if is_suspicious: + comment += new_comment snippet.libraries = PackageExtractor.extract_packages(snippet.code, snippet.language) diff --git a/src/codegate/pipeline/suspicious_commands/suspicious_commands.py b/src/codegate/pipeline/suspicious_commands/suspicious_commands.py index ca3c3e8e..45e5c30d 100644 --- a/src/codegate/pipeline/suspicious_commands/suspicious_commands.py +++ b/src/codegate/pipeline/suspicious_commands/suspicious_commands.py @@ -12,10 +12,13 @@ import numpy as np # Add this import import onnxruntime as ort +import structlog from codegate.config import Config from codegate.inference.inference_engine import LlamaCppInferenceEngine +logger = structlog.get_logger("codegate") + class SuspiciousCommands: """ @@ -123,22 +126,29 @@ async def check_suspicious_code(code, language=None): Returns: tuple: A comment string and a boolean indicating if the code is suspicious. """ + if language is None: + language = "code" + if language in [ + "python", + "javascript", + "typescript", + "go", + "rust", + "java", + ]: + logger.debug(f"Skipping suspicious command check for {language}") + return "", False + logger.debug("Checking code for suspicious commands") sc = SuspiciousCommands.get_instance() comment = "" class_, prob = await sc.classify_phrase(code) - if class_ == 1: + is_suspicious = class_ == 1 + if is_suspicious: liklihood = "possibly" if prob > 0.9: liklihood = "likely" - if language is None: - language = "code" - if language not in [ - "python", - "javascript", - "typescript", - "go", - "rust", - "java", - ]: - comment = f"{comment}\n\n🛡️ CodeGate: The {language} supplied is {liklihood} unsafe. Please check carefully!\n\n" # noqa: E501 - return comment, class_ == 1 + comment = f"{comment}\n\n🛡️ CodeGate: The {language} supplied is {liklihood} unsafe. Please check carefully!\n\n" # noqa: E501 + logger.info(f"Suspicious: {code}") + else: + logger.debug("Not Suspicious") + return comment, is_suspicious diff --git a/tests/test_suspicious_commands.py b/tests/test_suspicious_commands.py index ceafad6e..4840ece2 100644 --- a/tests/test_suspicious_commands.py +++ b/tests/test_suspicious_commands.py @@ -4,11 +4,13 @@ """ import csv import os +from unittest.mock import AsyncMock, patch import pytest from codegate.pipeline.suspicious_commands.suspicious_commands import ( SuspiciousCommands, + check_suspicious_code, ) try: @@ -189,3 +191,68 @@ async def test_classify_phrase_confident(sc): else: print(f"{command['cmd']} {prob} {prediction} 1") check_results(tp, tn, fp, fn) + + +@pytest.mark.asyncio +@patch("codegate.pipeline.suspicious_commands.suspicious_commands.SuspiciousCommands.get_instance") +async def test_check_suspicious_code_safe(mock_get_instance): + """ + Test check_suspicious_code with safe code. + """ + mock_instance = mock_get_instance.return_value + mock_instance.classify_phrase = AsyncMock(return_value=(0, 0.5)) + + code = "print('Hello, world!')" + comment, is_suspicious = await check_suspicious_code(code, "python") + + assert comment == "" + assert is_suspicious is False + + +@pytest.mark.asyncio +@patch("codegate.pipeline.suspicious_commands.suspicious_commands.SuspiciousCommands.get_instance") +async def test_check_suspicious_code_suspicious(mock_get_instance): + """ + Test check_suspicious_code with suspicious code. + """ + mock_instance = mock_get_instance.return_value + mock_instance.classify_phrase = AsyncMock(return_value=(1, 0.95)) + + code = "rm -rf /" + comment, is_suspicious = await check_suspicious_code(code, "bash") + + assert "🛡️ CodeGate: The bash supplied is likely unsafe." in comment + assert is_suspicious is True + + +@pytest.mark.asyncio +@patch("codegate.pipeline.suspicious_commands.suspicious_commands.SuspiciousCommands.get_instance") +async def test_check_suspicious_code_skipped_language(mock_get_instance): + """ + Test check_suspicious_code with a language that should be skipped. + """ + mock_instance = mock_get_instance.return_value + mock_instance.classify_phrase = AsyncMock() + + code = "print('Hello, world!')" + comment, is_suspicious = await check_suspicious_code(code, "python") + + assert comment == "" + assert is_suspicious is False + mock_instance.classify_phrase.assert_not_called() + + +@pytest.mark.asyncio +@patch("codegate.pipeline.suspicious_commands.suspicious_commands.SuspiciousCommands.get_instance") +async def test_check_suspicious_code_no_language(mock_get_instance): + """ + Test check_suspicious_code with no language specified. + """ + mock_instance = mock_get_instance.return_value + mock_instance.classify_phrase = AsyncMock(return_value=(1, 0.85)) + + code = "rm -rf /" + comment, is_suspicious = await check_suspicious_code(code) + + assert "🛡️ CodeGate: The code supplied is possibly unsafe." in comment + assert is_suspicious is True