diff --git a/src/ast/__tests__/expression-extractor.test.ts b/src/ast/__tests__/expression-extractor.test.ts index 3e3238b0..9a7d03b7 100644 --- a/src/ast/__tests__/expression-extractor.test.ts +++ b/src/ast/__tests__/expression-extractor.test.ts @@ -1035,3 +1035,142 @@ describe("extract ClassInstanceCreationExpression correctly", () => { expect(ast).toEqual(expectedAst); }); }); + +describe("extract CastExpression correctly", () => { + it("extract CastExpression int to char correctly", () => { + const programStr = ` + class Test { + void test() { + char c = (char) 65; + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "LocalVariableDeclarationStatement", + localVariableType: "char", + variableDeclaratorList: [ + { + kind: "VariableDeclarator", + variableDeclaratorId: "c", + variableInitializer: { + kind: "CastExpression", + type: "char", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "65", + }, + location: expect.anything(), + }, + location: expect.anything(), + }, + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract CastExpression double to int correctly", () => { + const programStr = ` + class Test { + void test() { + int x = (int) 3.14; + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "LocalVariableDeclarationStatement", + localVariableType: "int", + variableDeclaratorList: [ + { + kind: "VariableDeclarator", + variableDeclaratorId: "x", + variableInitializer: { + kind: "CastExpression", + type: "int", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalFloatingPointLiteral", + value: "3.14", + } + }, + location: expect.anything(), + }, + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); +}); diff --git a/src/ast/__tests__/switch-statement-extractor.test.ts b/src/ast/__tests__/switch-statement-extractor.test.ts new file mode 100644 index 00000000..00ae21ea --- /dev/null +++ b/src/ast/__tests__/switch-statement-extractor.test.ts @@ -0,0 +1,435 @@ +import { parse } from "../parser"; +import { AST } from "../types/packages-and-modules"; + +describe("extract SwitchStatement correctly", () => { + it("extract SwitchStatement with case labels and statements correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + default: + System.out.println("Default"); + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "2", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Two"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "DefaultLabel", + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Default"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract SwitchStatement with fallthrough correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + case 2: + System.out.println("One or Two"); + break; + default: + System.out.println("Default"); + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "2", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One or Two"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "DefaultLabel", + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Default"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract SwitchStatement without default case correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + System.out.println("One"); + break; + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + console.log(JSON.stringify(ast, null, 2)); + expect(ast).toEqual(expectedAst); + }); +}); diff --git a/src/ast/astExtractor/expression-extractor.ts b/src/ast/astExtractor/expression-extractor.ts index 3301db6a..dcce959c 100644 --- a/src/ast/astExtractor/expression-extractor.ts +++ b/src/ast/astExtractor/expression-extractor.ts @@ -2,6 +2,7 @@ import { ArgumentListCtx, BaseJavaCstVisitorWithDefaults, BinaryExpressionCtx, + CastExpressionCtx, ClassOrInterfaceTypeToInstantiateCtx, BooleanLiteralCtx, ExpressionCstNode, @@ -86,6 +87,62 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults { } } + castExpression(ctx: CastExpressionCtx) { + // Handle primitive cast expressions + if (ctx.primitiveCastExpression && ctx.primitiveCastExpression?.length > 0) { + const primitiveCast = ctx.primitiveCastExpression[0]; + const type = this.extractType(primitiveCast.children.primitiveType[0]); + const expression = this.visit(primitiveCast.children.unaryExpression[0]); + return { + kind: "CastExpression", + type: type, + expression: expression, + location: this.location, + }; + } + + throw new Error("Invalid CastExpression format."); + } + + private extractType(typeCtx: any): string { + // Check for the 'primitiveType' node + if (typeCtx.name === "primitiveType" && typeCtx.children) { + const { children } = typeCtx; + + // Handle 'numericType' (e.g., int, char, float, double) + if (children.numericType) { + const numericTypeCtx = children.numericType[0]; + + if (numericTypeCtx.children.integralType) { + // Handle integral types (e.g., char, int) + const integralTypeCtx = numericTypeCtx.children.integralType[0]; + + // Extract the specific type (e.g., 'char', 'int') + for (const key in integralTypeCtx.children) { + if (integralTypeCtx.children[key][0].image) { + return integralTypeCtx.children[key][0].image; + } + } + } + + if (numericTypeCtx.children.floatingPointType) { + // Handle floating-point types (e.g., float, double) + const floatingPointTypeCtx = numericTypeCtx.children.floatingPointType[0]; + + // Extract the specific type (e.g., 'float', 'double') + for (const key in floatingPointTypeCtx.children) { + if (floatingPointTypeCtx.children[key][0].image) { + return floatingPointTypeCtx.children[key][0].image; + } + } + } + } + } + + throw new Error("Invalid type context in cast expression."); + } + + private makeBinaryExpression( operators: IToken[], operands: UnaryExpressionCstNode[] @@ -174,6 +231,10 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults { } unaryExpression(ctx: UnaryExpressionCtx) { + if (ctx.primary[0].children.primaryPrefix[0].children.castExpression) { + return this.visit(ctx.primary[0].children.primaryPrefix[0].children.castExpression); + } + const node = this.visit(ctx.primary); if (ctx.UnaryPrefixOperator) { return { diff --git a/src/ast/astExtractor/statement-extractor.ts b/src/ast/astExtractor/statement-extractor.ts index ac2918cf..9e9f0e4c 100644 --- a/src/ast/astExtractor/statement-extractor.ts +++ b/src/ast/astExtractor/statement-extractor.ts @@ -21,6 +21,10 @@ import { PrimaryPrefixCtx, PrimarySuffixCtx, ReturnStatementCtx, + SwitchStatementCtx, + SwitchBlockCtx, + SwitchLabelCtx, + SwitchBlockStatementGroupCtx, StatementCstNode, StatementExpressionCtx, StatementWithoutTrailingSubstatementCtx, @@ -32,13 +36,17 @@ import { ExpressionStatementCtx, LocalVariableTypeCtx, VariableDeclaratorListCtx, - VariableDeclaratorCtx, -} from "java-parser"; + VariableDeclaratorCtx +} from 'java-parser' import { BasicForStatement, ExpressionStatement, IfStatement, MethodInvocation, + SwitchStatement, + SwitchCase, + CaseLabel, + DefaultLabel, Statement, StatementExpression, VariableDeclarator, @@ -80,6 +88,8 @@ export class StatementExtractor extends BaseJavaCstVisitorWithDefaults { return { kind: "BreakStatement" }; } else if (ctx.continueStatement) { return { kind: "ContinueStatement" }; + } else if (ctx.switchStatement) { + return this.visit(ctx.switchStatement); } else if (ctx.returnStatement) { const returnStatementExp = this.visit(ctx.returnStatement); return { @@ -90,6 +100,122 @@ export class StatementExtractor extends BaseJavaCstVisitorWithDefaults { } } + switchStatement(ctx: SwitchStatementCtx): SwitchStatement { + const expressionExtractor = new ExpressionExtractor(); + + return { + kind: "SwitchStatement", + expression: expressionExtractor.extract(ctx.expression[0]), + cases: ctx.switchBlock + ? this.visit(ctx.switchBlock) + : [], + location: ctx.Switch[0] + }; + } + + switchBlock(ctx: SwitchBlockCtx): Array { + const cases: Array = []; + let currentCase: SwitchCase; + + ctx.switchBlockStatementGroup?.forEach((group) => { + const extractedCase = this.visit(group); + + if (!currentCase) { + // First case in the switch block + currentCase = extractedCase; + cases.push(currentCase); + } else if (currentCase.statements && currentCase.statements.length === 0) { + // Fallthrough case, merge labels + currentCase.labels.push(...extractedCase.labels); + } else { + // New case with statements starts, push previous case and start new one + currentCase = extractedCase; + cases.push(currentCase); + } + }); + + return cases; + } + + switchBlockStatementGroup(ctx: SwitchBlockStatementGroupCtx): SwitchCase { + const blockStatementExtractor = new BlockStatementExtractor(); + + console.log(ctx.switchLabel) + + return { + kind: "SwitchCase", + labels: ctx.switchLabel.flatMap((label) => this.visit(label)), + statements: ctx.blockStatements + ? ctx.blockStatements.flatMap((blockStatements) => + blockStatements.children.blockStatement.map((stmt) => + blockStatementExtractor.extract(stmt) + ) + ) + : [], + }; + } + + // switchLabel(ctx: SwitchLabelCtx): CaseLabel | DefaultLabel { + // // Check if the context contains a "case" label + // if (ctx.caseOrDefaultLabel?.[0]?.children?.Case) { + // const expressionExtractor = new ExpressionExtractor(); + // // @ts-ignore + // const expressionCtx = ctx.caseOrDefaultLabel[0].children.caseLabelElement[0] + // .children.caseConstant[0].children.ternaryExpression[0].children; + // + // // Ensure the expression context is valid before proceeding + // if (!expressionCtx) { + // throw new Error("Invalid Case expression in switch label"); + // } + // + // const expression = expressionExtractor.ternaryExpression(expressionCtx); + // + // return { + // kind: "CaseLabel", + // expression: expression, + // }; + // } + // + // // Check if the context contains a "default" label + // if (ctx.caseOrDefaultLabel?.[0]?.children?.Default) { + // return { kind: "DefaultLabel" }; + // } + // + // // Throw an error if the context does not match expected patterns + // throw new Error("Invalid switch label: Neither 'case' nor 'default' found"); + // } + + switchLabel(ctx: SwitchLabelCtx): Array { + const expressionExtractor = new ExpressionExtractor(); + const labels: Array = []; + + // Process all case or default labels + for (const labelCtx of ctx.caseOrDefaultLabel) { + if (labelCtx.children.Case) { + // Extract the expression for the case label + const expressionCtx = labelCtx.children.caseLabelElement?.[0] + ?.children.caseConstant?.[0]?.children.ternaryExpression?.[0]?.children; + + if (!expressionCtx) { + throw new Error("Invalid Case expression in switch label"); + } + + labels.push({ + kind: "CaseLabel", + expression: expressionExtractor.ternaryExpression(expressionCtx), + }); + } else if (labelCtx.children.Default) { + labels.push({ kind: "DefaultLabel" }); + } + } + + if (labels.length === 0) { + throw new Error("Invalid switch label: Neither 'case' nor 'default' found"); + } + + return labels; + } + expressionStatement(ctx: ExpressionStatementCtx): ExpressionStatement { const stmtExp = this.visit(ctx.statementExpression); return { diff --git a/src/ast/types/blocks-and-statements.ts b/src/ast/types/blocks-and-statements.ts index fe5dc7ad..54a1ce9d 100644 --- a/src/ast/types/blocks-and-statements.ts +++ b/src/ast/types/blocks-and-statements.ts @@ -28,7 +28,8 @@ export type Statement = | IfStatement | WhileStatement | ForStatement - | EmptyStatement; + | EmptyStatement + | SwitchStatement; export interface EmptyStatement extends BaseNode { kind: "EmptyStatement"; @@ -66,6 +67,34 @@ export interface EnhancedForStatement extends BaseNode { kind: "EnhancedForStatement"; } +export interface SwitchStatement extends BaseNode { + kind: "SwitchStatement"; + expression: Expression; // The expression to evaluate for the switch + cases: Array; +} + +export interface SwitchCase extends BaseNode { + kind: "SwitchCase"; + labels: Array; // Labels for case blocks + statements?: Array; // Statements to execute for the case +} + +export type CaseLabel = CaseLiteralLabel | CaseExpressionLabel; + +export interface CaseLiteralLabel extends BaseNode { + kind: "CaseLabel"; + expression: Literal; // Literal values: byte, short, int, char, or String +} + +export interface CaseExpressionLabel extends BaseNode { + kind: "CaseLabel"; + expression: Expression; // For future extension if needed +} + +export interface DefaultLabel extends BaseNode { + kind: "DefaultLabel"; // Represents the default case +} + export type StatementWithoutTrailingSubstatement = | Block | ExpressionStatement @@ -259,7 +288,7 @@ export interface Assignment extends BaseNode { } export type LeftHandSide = ExpressionName | ArrayAccess; -export type UnaryExpression = PrefixExpression | PostfixExpression; +export type UnaryExpression = PrefixExpression | PostfixExpression | CastExpression; export interface PrefixExpression extends BaseNode { kind: "PrefixExpression"; @@ -289,3 +318,9 @@ export interface TernaryExpression extends BaseNode { consequent: Expression; alternate: Expression; } + +export interface CastExpression extends BaseNode { + kind: "CastExpression"; + type: UnannType; + expression: Expression; +} diff --git a/src/compiler/__tests__/index.ts b/src/compiler/__tests__/index.ts index 8f31ef5c..269d5d8f 100644 --- a/src/compiler/__tests__/index.ts +++ b/src/compiler/__tests__/index.ts @@ -1,25 +1,31 @@ -import { printlnTest } from "./tests/println.test"; -import { variableDeclarationTest } from "./tests/variableDeclaration.test"; -import { arithmeticExpressionTest } from "./tests/arithmeticExpression.test"; -import { ifElseTest } from "./tests/ifElse.test"; -import { whileTest } from "./tests/while.test"; -import { forTest } from "./tests/for.test"; -import { unaryExpressionTest } from "./tests/unaryExpression.test"; -import { methodInvocationTest } from "./tests/methodInvocation.test"; -import { importTest } from "./tests/import.test"; -import { arrayTest } from "./tests/array.test"; -import { classTest } from "./tests/class.test"; +import { printlnTest } from './tests/println.test' +import { variableDeclarationTest } from './tests/variableDeclaration.test' +import { arithmeticExpressionTest } from './tests/arithmeticExpression.test' +import { ifElseTest } from './tests/ifElse.test' +import { whileTest } from './tests/while.test' +import { forTest } from './tests/for.test' +import { unaryExpressionTest } from './tests/unaryExpression.test' +import { methodInvocationTest } from './tests/methodInvocation.test' +import { importTest } from './tests/import.test' +import { arrayTest } from './tests/array.test' +import { classTest } from './tests/class.test' +import { assignmentExpressionTest } from './tests/assignmentExpression.test' +import { castExpressionTest } from './tests/castExpression.test' +import { switchTest } from './tests/switch.test' -describe("compiler tests", () => { - printlnTest(); - variableDeclarationTest(); - arithmeticExpressionTest(); - unaryExpressionTest(); - ifElseTest(); - whileTest(); - forTest(); - methodInvocationTest(); - importTest(); - arrayTest(); - classTest(); -}) \ No newline at end of file +describe('compiler tests', () => { + switchTest() + castExpressionTest() + printlnTest() + variableDeclarationTest() + arithmeticExpressionTest() + unaryExpressionTest() + ifElseTest() + whileTest() + forTest() + methodInvocationTest() + importTest() + arrayTest() + classTest() + assignmentExpressionTest() +}) diff --git a/src/compiler/__tests__/tests/arithmeticExpression.test.ts b/src/compiler/__tests__/tests/arithmeticExpression.test.ts index abe02048..72a38c1d 100644 --- a/src/compiler/__tests__/tests/arithmeticExpression.test.ts +++ b/src/compiler/__tests__/tests/arithmeticExpression.test.ts @@ -78,6 +78,58 @@ const testCases: testCase[] = [ expectedLines: ["-2147483648", "-32769", "-32768", "-129", "-128", "-1", "0", "1", "127", "128", "32767", "32768", "2147483647"], }, + { + comment: "Mixed int and float addition (order swapped)", + program: ` + public class Main { + public static void main(String[] args) { + int a = 5; + float b = 2.5f; + System.out.println(a + b); + } + } + `, + expectedLines: ["7.5"], + }, + { + comment: "Mixed long and double multiplication", + program: ` + public class Main { + public static void main(String[] args) { + double a = 3.5; + long b = 10L; + System.out.println(a * b); + } + } + `, + expectedLines: ["35.0"], + }, + { + comment: "Mixed long and double multiplication (order swapped)", + program: ` + public class Main { + public static void main(String[] args) { + long a = 10L; + double b = 3.5; + System.out.println(a * b); + } + } + `, + expectedLines: ["35.0"], + }, + { + comment: "Mixed int and double division", + program: ` + public class Main { + public static void main(String[] args) { + double a = 2.0; + int b = 5; + System.out.println(a / b); + } + } + `, + expectedLines: ["0.4"], + } ]; export const arithmeticExpressionTest = () => describe("arithmetic expression", () => { diff --git a/src/compiler/__tests__/tests/assignmentExpression.test.ts b/src/compiler/__tests__/tests/assignmentExpression.test.ts new file mode 100644 index 00000000..5950de37 --- /dev/null +++ b/src/compiler/__tests__/tests/assignmentExpression.test.ts @@ -0,0 +1,124 @@ +import { + runTest, + testCase, +} from "../__utils__/test-utils"; + +const testCases: testCase[] = [ + { + comment: "int to double assignment", + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + double y = x; + System.out.println(y); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "int to double conversion", + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + double y; + y = x; + System.out.println(y); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "int to double conversion, array", + program: ` + public class Main { + public static void main(String[] args) { + int x = 6; + double[] y = {1.0, 2.0, 3.0, 4.0, 5.0}; + y[1] = x; + System.out.println(y[1]); + } + } + `, + expectedLines: ["6.0"], + }, + { + comment: "int to long", + program: ` + public class Main { + public static void main(String[] args) { + int a = 123; + long b = a; + System.out.println(b); + } + } + `, + expectedLines: ["123"], + }, + { + comment: "int to float", + program: ` + public class Main { + public static void main(String[] args) { + int a = 123; + float b = a; + System.out.println(b); + } + } + `, + expectedLines: ["123.0"], + }, + + // long -> other types + { + comment: "long to float", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + float b = a; + System.out.println(b); + } + } + `, + expectedLines: ["9.223372E18"], + }, + { + comment: "long to double", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + double b = a; + System.out.println(b); + } + } + `, + expectedLines: ["9.223372036854776E18"], + }, + + // float -> other types + { + comment: "float to double", + program: ` + public class Main { + public static void main(String[] args) { + float a = 3.0f; + double b = a; + System.out.println(b); + } + } + `, + expectedLines: ["3.0"], + }, +]; + +export const assignmentExpressionTest = () => describe("assignment expression", () => { + for (let testCase of testCases) { + const { comment: comment, program: program, expectedLines: expectedLines } = testCase; + it(comment, () => runTest(program, expectedLines)); + } +}); diff --git a/src/compiler/__tests__/tests/castExpression.test.ts b/src/compiler/__tests__/tests/castExpression.test.ts new file mode 100644 index 00000000..e811ec66 --- /dev/null +++ b/src/compiler/__tests__/tests/castExpression.test.ts @@ -0,0 +1,145 @@ +import { + runTest, + testCase, +} from "../__utils__/test-utils"; + +const testCases: testCase[] = [ + { + comment: "Simple primitive casting: int to float", + program: ` + public class Main { + public static void main(String[] args) { + int a = 5; + float b = (float) a; + System.out.println(b); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "Simple primitive casting: float to int", + program: ` + public class Main { + public static void main(String[] args) { + float a = 2.9f; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["2"], + }, + { + comment: "Primitive casting: double to long", + program: ` + public class Main { + public static void main(String[] args) { + double a = 123456789.987; + long b = (long) a; + System.out.println(b); + } + } + `, + expectedLines: ["123456789"], + }, + { + comment: "Primitive casting: long to byte", + program: ` + public class Main { + public static void main(String[] args) { + long a = 257; + byte b = (byte) a; + System.out.println(b); + } + } + `, + expectedLines: ["1"], // byte wraps around at 256 + }, + { + comment: "Primitive casting: char to int", + program: ` + public class Main { + public static void main(String[] args) { + char a = 'A'; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["65"], + }, + { + comment: "Primitive casting: int to char", + program: ` + public class Main { + public static void main(String[] args) { + int a = 65; + char b = (char) a; + System.out.println(b); + } + } + `, + expectedLines: ["A"], + }, + { + comment: "Primitive casting: int to char", + program: ` + public class Main { + public static void main(String[] args) { + int a = 66; + char b = (char) a; + System.out.println(b); + } + } + `, + expectedLines: ["B"], + }, + { + comment: "Primitive casting with loss of precision", + program: ` + public class Main { + public static void main(String[] args) { + double a = 123.456; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["123"], + }, + { + comment: "Primitive casting: float to short", + program: ` + public class Main { + public static void main(String[] args) { + float a = 32768.0f; + short b = (short) a; + System.out.println(b); + } + } + `, + expectedLines: ["-32768"], // short wraps around + }, + { + comment: "Chained casting: double to int to byte", + program: ` + public class Main { + public static void main(String[] args) { + double a = 258.99; + int b = (int) a; + byte c = (byte) b; + System.out.println(c); + } + } + `, + expectedLines: ["2"], // 258 -> byte wraps around + }, +]; + +export const castExpressionTest = () => describe("cast expression", () => { + for (let testCase of testCases) { + const { comment: comment, program: program, expectedLines: expectedLines } = testCase; + it(comment, () => runTest(program, expectedLines)); + } +}); \ No newline at end of file diff --git a/src/compiler/__tests__/tests/println.test.ts b/src/compiler/__tests__/tests/println.test.ts index 4e2e4635..9b5ebbea 100644 --- a/src/compiler/__tests__/tests/println.test.ts +++ b/src/compiler/__tests__/tests/println.test.ts @@ -98,6 +98,22 @@ const testCases: testCase[] = [ `, expectedLines: ["true", "false"], }, + { + comment: "println with concatenated arguments", + program: ` + public class Main { + public static void main(String[] args) { + System.out.println("Hello" + " " + "world!"); + System.out.println("This is an int: " + 123); + System.out.println("This is a float: " + 4.5f); + System.out.println("This is a long: " + 10000000000L); + System.out.println("This is a double: " + 10.3); + } + } + `, + expectedLines: ["Hello world!", "This is an int: 123", "This is a float: 4.5", + "This is a long: 10000000000", "This is a double: 10.3"], + }, { comment: "multiple println statements", program: ` diff --git a/src/compiler/__tests__/tests/switch.test.ts b/src/compiler/__tests__/tests/switch.test.ts new file mode 100644 index 00000000..003e9ff4 --- /dev/null +++ b/src/compiler/__tests__/tests/switch.test.ts @@ -0,0 +1,203 @@ +import { runTest, testCase } from '../__utils__/test-utils' + +const testCases: testCase[] = [ + { + comment: 'More basic switch case', + program: ` + public class Main { + public static void main(String[] args) { + int x = 1; + switch (x) { + case 1: + System.out.println("One"); + break; + } + } + } + `, + expectedLines: ['One'] + }, + { + comment: 'Basic switch case', + program: ` + public class Main { + public static void main(String[] args) { + int x = 2; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + case 3: + System.out.println("Three"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Two'] + }, + { + comment: 'Switch with default case', + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Default'] + }, + { + comment: 'Switch fallthrough behavior', + program: ` + public class Main { + public static void main(String[] args) { + int x = 2; + switch (x) { + case 1: + System.out.println("One"); + case 2: + System.out.println("Two"); + case 3: + System.out.println("Three"); + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Two', 'Three', 'Default'] + }, + { + comment: 'Switch with break statements', + program: ` + public class Main { + public static void main(String[] args) { + int x = 3; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + case 3: + System.out.println("Three"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Three'] + }, + { + comment: 'Switch with strings', + program: ` + public class Main { + public static void main(String[] args) { + String day = "Tuesday"; + switch (day) { + case "Monday": + System.out.println("Start of the week"); + break; + case "Tuesday": + System.out.println("Second day"); + break; + case "Friday": + System.out.println("Almost weekend"); + break; + default: + System.out.println("Midweek or weekend"); + } + } + } + `, + expectedLines: ['Second day'] + }, + { + comment: 'Nested switch statements', + program: ` + public class Main { + public static void main(String[] args) { + int outer = 2; + int inner = 1; + switch (outer) { + case 1: + switch (inner) { + case 1: + System.out.println("Inner One"); + break; + case 2: + System.out.println("Inner Two"); + break; + } + break; + case 2: + switch (inner) { + case 1: + System.out.println("Outer Two, Inner One"); + break; + case 2: + System.out.println("Outer Two, Inner Two"); + break; + } + break; + default: + System.out.println("Default case"); + } + } + } + `, + expectedLines: ['Outer Two, Inner One'] + }, + + { + comment: 'Switch with far apart cases', + program: ` + public class Main { + public static void main(String[] args) { + int x = 1331; + switch (x) { + case 1: + System.out.println("No"); + break; + case 1331: + System.out.println("Yes"); + break; + case 999999999: + System.out.println("No"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Yes'] + } +] + +export const switchTest = () => + describe('Switch statements', () => { + for (let testCase of testCases) { + const { comment, program, expectedLines } = testCase + it(comment, () => runTest(program, expectedLines)) + } + }) diff --git a/src/compiler/__tests__/tests/unaryExpression.test.ts b/src/compiler/__tests__/tests/unaryExpression.test.ts index 175c6a2d..0e9aa469 100644 --- a/src/compiler/__tests__/tests/unaryExpression.test.ts +++ b/src/compiler/__tests__/tests/unaryExpression.test.ts @@ -159,6 +159,45 @@ const testCases: testCase[] = [ }`, expectedLines: ["10", "10", "-10", "-10", "-10", "-10", "10", "9", "-10"], }, + { + comment: "unary plus/minus for long", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["9223372036854775807", "-9223372036854775807"], + }, + { + comment: "unary plus/minus for float", + program: ` + public class Main { + public static void main(String[] args) { + float a = 4.5f; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["4.5", "-4.5"], + }, + { + comment: "unary plus/minus for double", + program: ` + public class Main { + public static void main(String[] args) { + double a = 10.75; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["10.75", "-10.75"], + }, { comment: "bitwise complement", program: ` diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 1b0d8c86..a59ba776 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -24,7 +24,7 @@ import { ClassInstanceCreationExpression, ExpressionStatement, TernaryExpression, - LeftHandSide + LeftHandSide, CastExpression, SwitchStatement, SwitchCase, CaseLabel } from '../ast/types/blocks-and-statements' import { MethodDeclaration, UnannType } from '../ast/types/classes' import { ConstantPoolManager } from './constant-pool-manager' @@ -164,12 +164,177 @@ const normalStoreOp: { [type: string]: OPCODE } = { Z: OPCODE.ISTORE } +const typeConversions: { [key: string]: OPCODE } = { + 'I->F': OPCODE.I2F, + 'I->D': OPCODE.I2D, + 'I->J': OPCODE.I2L, + 'I->B': OPCODE.I2B, + 'I->C': OPCODE.I2C, + 'I->S': OPCODE.I2S, + 'F->D': OPCODE.F2D, + 'F->I': OPCODE.F2I, + 'F->J': OPCODE.F2L, + 'D->F': OPCODE.D2F, + 'D->I': OPCODE.D2I, + 'D->J': OPCODE.D2L, + 'J->I': OPCODE.L2I, + 'J->F': OPCODE.L2F, + 'J->D': OPCODE.L2D +}; + +const typeConversionsImplicit: { [key: string]: OPCODE } = { + 'I->F': OPCODE.I2F, + 'I->D': OPCODE.I2D, + 'I->J': OPCODE.I2L, + 'F->D': OPCODE.F2D, + 'J->F': OPCODE.L2F, + 'J->D': OPCODE.L2D +} + type CompileResult = { stackSize: number resultType: string } const EMPTY_TYPE: string = '' +function areClassTypesCompatible(fromType: string, toType: string): boolean { + const cleanFrom = fromType.replace(/^L|;$/g, '') + const cleanTo = toType.replace(/^L|;$/g, '') + return cleanFrom === cleanTo +} + +function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator): number { + if (fromType === toType || toType.replace(/^L|;$/g, '') === 'java/lang/String') { + return 0; + } + + if (fromType.startsWith('L') || toType.startsWith('L')) { + if (areClassTypesCompatible(fromType, toType) || fromType === '') { + return 0; + } + throw new Error(`Unsupported class type conversion: ${fromType} -> ${toType}`) + } + + const conversionKey = `${fromType}->${toType}` + if (conversionKey in typeConversionsImplicit) { + cg.code.push(typeConversionsImplicit[conversionKey]) + if (!(fromType in ['J', 'D']) && toType in ['J', 'D']) { + return 1; + } else if (!(toType in ['J', 'D']) && fromType in ['J', 'D']) { + return -1; + } else { + return 0; + } + } else { + throw new Error(`Unsupported implicit type conversion: ${conversionKey}`) + } +} + +function handleExplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { + if (fromType === toType) { + return; + } + const conversionKey = `${fromType}->${toType}`; + if (conversionKey in typeConversions) { + cg.code.push(typeConversions[conversionKey]); + } else { + throw new Error(`Unsupported explicit type conversion: ${conversionKey}`); + } +} + +function generateStringConversion(valueType: string, cg: CodeGenerator): void { + const stringClass = 'java/lang/String'; + + // Map primitive types to `String.valueOf()` method descriptors + const valueOfDescriptors: { [key: string]: string } = { + I: '(I)Ljava/lang/String;', // int + J: '(J)Ljava/lang/String;', // long + F: '(F)Ljava/lang/String;', // float + D: '(D)Ljava/lang/String;', // double + Z: '(Z)Ljava/lang/String;', // boolean + B: '(B)Ljava/lang/String;', // byte + S: '(S)Ljava/lang/String;', // short + C: '(C)Ljava/lang/String;' // char + }; + + const descriptor = valueOfDescriptors[valueType]; + if (!descriptor) { + throw new Error(`Unsupported primitive type for String conversion: ${valueType}`); + } + + const methodIndex = cg.constantPoolManager.indexMethodrefInfo( + stringClass, + 'valueOf', + descriptor + ); + + cg.code.push(OPCODE.INVOKESTATIC, 0, methodIndex); +} + +function hashCode(str: string): number { + let hash = 0; + for (let i = 0; i < str.length; i++) { + hash = ((hash * 31) + str.charCodeAt(i)); // Simulate Java's overflow behavior + } + return hash; +} + +// function generateBooleanConversion(type: string, cg: CodeGenerator): number { +// let stackChange = 0; // Tracks changes to the stack size +// +// switch (type) { +// case 'I': // int +// case 'B': // byte +// case 'S': // short +// case 'C': // char +// // For integer-like types, compare with zero +// cg.code.push(OPCODE.ICONST_0); // Push 0 +// stackChange += 1; // `ICONST_0` pushes a value onto the stack +// cg.code.push(OPCODE.IF_ICMPNE); // Compare and branch +// stackChange -= 2; // `IF_ICMPNE` consumes two values from the stack +// break; +// +// case 'J': // long +// // For long, compare with zero +// cg.code.push(OPCODE.LCONST_0); // Push 0L +// stackChange += 2; // `LCONST_0` pushes two values onto the stack (long takes 2 slots) +// cg.code.push(OPCODE.LCMP); // Compare top two longs +// stackChange -= 4; // `LCMP` consumes four values (two long operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'F': // float +// // For float, compare with zero +// cg.code.push(OPCODE.FCONST_0); // Push 0.0f +// stackChange += 1; // `FCONST_0` pushes a value onto the stack +// cg.code.push(OPCODE.FCMPL); // Compare top two floats +// stackChange -= 2; // `FCMPL` consumes two values (float operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'D': // double +// // For double, compare with zero +// cg.code.push(OPCODE.DCONST_0); // Push 0.0d +// stackChange += 2; // `DCONST_0` pushes two values onto the stack (double takes 2 slots) +// cg.code.push(OPCODE.DCMPL); // Compare top two doubles +// stackChange -= 4; // `DCMPL` consumes four values (two double operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'Z': // boolean +// // Already a boolean, no conversion needed +// break; +// +// default: +// throw new Error(`Cannot convert type ${type} to boolean.`); +// } +// +// return stackChange; // Return the net change in stack size +// } + const isNullLiteral = (node: Node) => { return node.kind === 'Literal' && node.literalType.kind === 'NullLiteral' } @@ -245,13 +410,16 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi vi.forEach((val, i) => { cg.code.push(OPCODE.DUP) const size1 = compile(createIntLiteralNode(i), cg).stackSize - const size2 = compile(val as Expression, cg).stackSize + const { stackSize: size2, resultType } = compile(val as Expression, cg) + const stackSizeChange = handleImplicitTypeConversion(resultType, arrayElemType, cg) cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) - maxStack = Math.max(maxStack, 2 + size1 + size2) + maxStack = Math.max(maxStack, 2 + size1 + size2 + stackSizeChange) }) cg.code.push(OPCODE.ASTORE, curIdx) } else { - maxStack = Math.max(maxStack, compile(vi, cg).stackSize) + const { stackSize: initializerStackSize, resultType: initializerType } = compile(vi, cg) + const stackSizeChange = handleImplicitTypeConversion(initializerType, variableInfo.typeDescriptor, cg) + maxStack = Math.max(maxStack, initializerStackSize + stackSizeChange) cg.code.push( variableInfo.typeDescriptor in normalStoreOp ? normalStoreOp[variableInfo.typeDescriptor] @@ -276,8 +444,16 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }, BreakStatement: (node: Node, cg: CodeGenerator) => { - cg.addBranchInstr(OPCODE.GOTO, cg.loopLabels[cg.loopLabels.length - 1][1]) - return { stackSize: 0, resultType: EMPTY_TYPE } + if (cg.loopLabels.length > 0) { + // If inside a loop, break jumps to the end of the loop + cg.addBranchInstr(OPCODE.GOTO, cg.loopLabels[cg.loopLabels.length - 1][1]); + } else if (cg.switchLabels.length > 0) { + // If inside a switch, break jumps to the switch's end label + cg.addBranchInstr(OPCODE.GOTO, cg.switchLabels[cg.switchLabels.length - 1]); + } else { + throw new Error("Break statement not inside a loop or switch statement"); + } + return { stackSize: 0, resultType: EMPTY_TYPE }; }, ContinueStatement: (node: Node, cg: CodeGenerator) => { @@ -429,6 +605,11 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.addBranchInstr(OPCODE.GOTO, targetLabel) } return { stackSize: 0, resultType: cg.symbolTable.generateFieldDescriptor('boolean') } + } else { + if (onTrue === (parseInt(value) !== 0)) { + cg.addBranchInstr(OPCODE.GOTO, targetLabel) + } + return { stackSize: 0, resultType: cg.symbolTable.generateFieldDescriptor('boolean') } } } @@ -572,6 +753,10 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi let resultType = EMPTY_TYPE const symbolInfos = cg.symbolTable.queryMethod(n.identifier) + if (!symbolInfos || symbolInfos.length === 0) { + throw new Error(`Method not found: ${n.identifier}`) + } + for (let i = 0; i < symbolInfos.length - 1; i++) { if (i === 0) { const varInfo = symbolInfos[i] as VariableInfo @@ -594,10 +779,36 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } const argTypes: Array = [] + + const methodInfo = symbolInfos[symbolInfos.length - 1] as MethodInfos + if (!methodInfo || methodInfo.length === 0) { + throw new Error(`No method information found for ${n.identifier}`) + } + + const fullDescriptor = methodInfo[0].typeDescriptor // Full descriptor, e.g., "(Ljava/lang/String;C)V" + const paramDescriptor = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) // Extract "Ljava/lang/String;C" + const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) + + // Parse individual parameter types + if (params && params.length !== n.argumentList.length) { + throw new Error( + `Parameter mismatch: expected ${params?.length || 0}, got ${n.argumentList.length}` + ) + } + n.argumentList.forEach((x, i) => { const argCompileResult = compile(x, cg) - maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize) - argTypes.push(argCompileResult.resultType) + + let normalizedType = argCompileResult.resultType; + if (normalizedType === 'B' || normalizedType === 'S') { + normalizedType = 'I' + } + + const expectedType = params?.[i] // Expected parameter type + const stackSizeChange = handleImplicitTypeConversion(normalizedType, expectedType ?? '', cg) + maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize + stackSizeChange) + + argTypes.push(normalizedType) }) const argDescriptor = '(' + argTypes.join('') + ')' @@ -632,7 +843,9 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } if (!foundMethod) { - throw new InvalidMethodCallError(n.identifier) + throw new InvalidMethodCallError( + `No method matching signature ${n.identifier}${argDescriptor} found.` + ) } return { stackSize: maxStack, resultType: resultType } }, @@ -662,15 +875,20 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi if (lhs.kind === 'ArrayAccess') { const { stackSize: size1, resultType: arrayType } = compile(lhs.primary, cg) const size2 = compile(lhs.expression, cg).stackSize - maxStack = size1 + size2 + compile(right, cg).stackSize + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) + const arrayElemType = arrayType.slice(1) + const stackSizeChange = handleImplicitTypeConversion(rhsType, arrayElemType, cg) + maxStack = Math.max(maxStack, size1 + size2 + rhsSize + stackSizeChange) cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) } else if ( lhs.kind === 'ExpressionName' && !Array.isArray(cg.symbolTable.queryVariable(lhs.name)) ) { const info = cg.symbolTable.queryVariable(lhs.name) as VariableInfo - maxStack = 1 + compile(right, cg).stackSize + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) + const stackSizeChange = handleImplicitTypeConversion(rhsType, info.typeDescriptor, cg) + maxStack = Math.max(maxStack, 1 + rhsSize + stackSizeChange) cg.code.push( info.typeDescriptor in normalStoreOp ? normalStoreOp[info.typeDescriptor] : OPCODE.ASTORE, info.index @@ -693,7 +911,11 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.code.push(OPCODE.ALOAD, 0) maxStack += 1 } - maxStack += compile(right, cg).stackSize + + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) + const stackSizeChange = handleImplicitTypeConversion(rhsType, fieldInfo.typeDescriptor, cg) + + maxStack = Math.max(maxStack, maxStack + rhsSize + stackSizeChange) cg.code.push( fieldInfo.accessFlags & FIELD_FLAGS.ACC_STATIC ? OPCODE.PUTSTATIC : OPCODE.PUTFIELD, 0, @@ -737,33 +959,107 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } } - const { stackSize: size1, resultType: type } = compile(left, cg) - const { stackSize: size2 } = compile(right, cg) + const { stackSize: size1, resultType: leftType } = compile(left, cg) + const insertConversionIndex = cg.code.length; + cg.code.push(OPCODE.NOP); + const { stackSize: size2, resultType: rightType } = compile(right, cg) + + if (op === '+' && + (leftType === 'Ljava/lang/String;' + || rightType === 'Ljava/lang/String;')) { + if (leftType !== 'Ljava/lang/String;') { + generateStringConversion(leftType, cg); + } + + if (rightType !== 'Ljava/lang/String;') { + generateStringConversion(rightType, cg); + } + + // Invoke `String.concat` for concatenation + const concatMethodIndex = cg.constantPoolManager.indexMethodrefInfo( + 'java/lang/String', + 'concat', + '(Ljava/lang/String;)Ljava/lang/String;' + ); + cg.code.push(OPCODE.INVOKEVIRTUAL, 0, concatMethodIndex); + + return { + stackSize: Math.max(size1 + 1, size2 + 1), // Max stack size plus one for the concatenation + resultType: 'Ljava/lang/String;' + }; + } + + let finalType = leftType; + + if (leftType !== rightType) { + const conversionKeyLeft = `${leftType}->${rightType}` + const conversionKeyRight = `${rightType}->${leftType}` + + if (['D', 'F'].includes(leftType) || ['D', 'F'].includes(rightType)) { + // Promote both to double if one is double, or to float otherwise + if (leftType !== 'D' && rightType === 'D') { + cg.code.fill(typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, insertConversionIndex + 1) + finalType = 'D'; + } else if (leftType === 'D' && rightType !== 'D') { + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + finalType = 'D'; + } else if (leftType !== 'F' && rightType === 'F') { + // handleImplicitTypeConversion(leftType, 'F', cg); + cg.code.fill(typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, insertConversionIndex + 1) + finalType = 'F'; + } else if (leftType === 'F' && rightType !== 'F') { + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + finalType = 'F'; + } + } else if (['J'].includes(leftType) || ['J'].includes(rightType)) { + // Promote both to long if one is long + if (leftType !== 'J' && rightType === 'J') { + cg.code.fill(typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, insertConversionIndex + 1) + } else if (leftType === 'J' && rightType !== 'J') { + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + } + finalType = 'J'; + } else { + // Promote both to int as the common type for smaller types like byte, short, char + if (leftType !== 'I') { + cg.code.fill(typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, insertConversionIndex + 1) + } + if (rightType !== 'I') { + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + } + finalType = 'I'; + } + } - switch (type) { + // Perform the operation + switch (finalType) { case 'B': - cg.code.push(intBinaryOp[op], OPCODE.I2B) - break + cg.code.push(intBinaryOp[op], OPCODE.I2B); + break; case 'D': - cg.code.push(doubleBinaryOp[op]) - break + cg.code.push(doubleBinaryOp[op]); + break; case 'F': - cg.code.push(floatBinaryOp[op]) - break + cg.code.push(floatBinaryOp[op]); + break; case 'I': - cg.code.push(intBinaryOp[op]) - break + cg.code.push(intBinaryOp[op]); + break; case 'J': - cg.code.push(longBinaryOp[op]) - break + cg.code.push(longBinaryOp[op]); + break; case 'S': - cg.code.push(intBinaryOp[op], OPCODE.I2S) - break + cg.code.push(intBinaryOp[op], OPCODE.I2S); + break; } return { - stackSize: Math.max(size1, 1 + (['D', 'J'].includes(type) ? 1 : 0) + size2), - resultType: type + stackSize: Math.max(size1, 1 + (['D', 'J'].includes(finalType) ? 1 : 0) + size2), + resultType: finalType } }, @@ -799,7 +1095,18 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const compileResult = compile(expr, cg) if (op === '-') { - cg.code.push(OPCODE.INEG) + const negationOpcodes: { [type: string]: OPCODE } = { + I: OPCODE.INEG, // Integer negation + J: OPCODE.LNEG, // Long negation + F: OPCODE.FNEG, // Float negation + D: OPCODE.DNEG, // Double negation + }; + + if (compileResult.resultType in negationOpcodes) { + cg.code.push(negationOpcodes[compileResult.resultType]); + } else { + throw new Error(`Unary '-' not supported for type: ${compileResult.resultType}`); + } } else if (op === '~') { cg.code.push(OPCODE.ICONST_M1, OPCODE.IXOR) compileResult.stackSize = Math.max(compileResult.stackSize, 2) @@ -948,6 +1255,320 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } return { stackSize: 1, resultType: EMPTY_TYPE } + }, + + CastExpression: (node: Node, cg: CodeGenerator) => { + const { expression, type } = node as CastExpression; // CastExpression node structure + const { stackSize, resultType } = compile(expression, cg); + + if ((type == 'byte' || type == 'short') && resultType != 'I') { + handleExplicitTypeConversion(resultType, 'I', cg); + handleExplicitTypeConversion('I', cg.symbolTable.generateFieldDescriptor(type), cg); + } else if (resultType == 'C') { + if (type == 'int') { + return { + stackSize, + resultType: cg.symbolTable.generateFieldDescriptor('int'), + }; + } else { + throw new Error(`Unsupported class type conversion: + ${'C'} -> ${cg.symbolTable.generateFieldDescriptor(type)}`) + } + } else if (type == 'char') { + if (resultType == 'I') { + handleExplicitTypeConversion('I', 'C', cg) + } else { + throw new Error(`Unsupported class type conversion: + ${resultType} -> ${cg.symbolTable.generateFieldDescriptor(type)}`) + } + } else { + handleExplicitTypeConversion(resultType, cg.symbolTable.generateFieldDescriptor(type), cg); + } + + return { + stackSize, + resultType: cg.symbolTable.generateFieldDescriptor(type), + } + }, + + SwitchStatement: (node: Node, cg: CodeGenerator) => { + const { expression, cases } = node as SwitchStatement; + + // Compile the switch expression + const { stackSize: exprStackSize, resultType } = compile(expression, cg); + let maxStack = exprStackSize; + + const caseLabels: Label[] = cases.map(() => cg.generateNewLabel()); + const defaultLabel = cg.generateNewLabel(); + const endLabel = cg.generateNewLabel(); + + // Track the switch statement's end label + cg.switchLabels.push(endLabel); + + if (["I", "B", "S", "C"].includes(resultType)) { + const caseValues: number[] = []; + const caseLabelMap: Map = new Map(); + let hasDefault = false; + const positionOffset = cg.code.length; + + cases.forEach((caseGroup, index) => { + caseGroup.labels.forEach((label) => { + if (label.kind === "CaseLabel") { + const value = parseInt((label.expression as Literal).literalType.value); + caseValues.push(value); + caseLabelMap.set(value, caseLabels[index]); + } else if (label.kind === "DefaultLabel") { + caseLabels[index] = defaultLabel; + hasDefault = true; + } + }); + }); + + const [minValue, maxValue] = [Math.min(...caseValues), Math.max(...caseValues)]; + const useTableSwitch = maxValue - minValue < caseValues.length * 2; + const caseLabelIndex: number[] = [] + let indexTracker = cg.code.length; + + if (useTableSwitch) { + cg.code.push(OPCODE.TABLESWITCH); + indexTracker++ + + // Ensure 4-byte alignment for TABLESWITCH + while (cg.code.length % 4 !== 0) { + cg.code.push(0); // Padding bytes (JVM requires alignment) + indexTracker++ + } + + // Add default branch (jump to default label) + cg.code.push(0, 0, 0, defaultLabel.offset); + caseLabelIndex.push(indexTracker + 3); + indexTracker += 4 + + // Push low and high values (min and max case values) + cg.code.push(minValue >> 24, (minValue >> 16) & 0xff, (minValue >> 8) & 0xff, minValue & 0xff); + cg.code.push(maxValue >> 24, (maxValue >> 16) & 0xff, (maxValue >> 8) & 0xff, maxValue & 0xff); + indexTracker += 8 + + // Generate branch table (map each value to a case label) + for (let i = minValue; i <= maxValue; i++) { + const caseIndex = caseValues.indexOf(i); + cg.code.push(0, 0, 0, caseIndex !== -1 ? caseLabels[caseIndex].offset + : defaultLabel.offset); + caseLabelIndex.push(indexTracker + 3); + indexTracker += 4; + } + } else { + cg.code.push(OPCODE.LOOKUPSWITCH); + indexTracker++; + + // Ensure 4-byte alignment for LOOKUPSWITCH + while (cg.code.length % 4 !== 0) { + cg.code.push(0); + indexTracker++ + } + + // Add default branch (jump to default label) + cg.code.push(0, 0, 0, defaultLabel.offset); + caseLabelIndex.push(indexTracker + 3); + indexTracker += 4 + + // Push the number of case-value pairs + cg.code.push((caseValues.length >> 24) & 0xff, (caseValues.length >> 16) & 0xff, + (caseValues.length >> 8) & 0xff, caseValues.length & 0xff); + indexTracker += 4 + + // Generate lookup table (pairs of case values and corresponding labels) + caseValues.forEach((value, index) => { + cg.code.push(value >> 24, (value >> 16) & 0xff, (value >> 8) & 0xff, value & 0xff); + cg.code.push(0, 0, 0, caseLabels[index].offset); + caseLabelIndex.push(indexTracker + 7); + indexTracker += 8; + }); + } + + // **Process case bodies with proper fallthrough handling** + let previousCase: SwitchCase | null = null; + + const nonDefaultCases = cases.filter((caseGroup) => + caseGroup.labels.some((label) => label.kind === "CaseLabel")) + + nonDefaultCases.forEach((caseGroup, index) => { + caseLabels[index].offset = cg.code.length; + + // Ensure statements array is always defined + caseGroup.statements = caseGroup.statements || []; + + // If previous case had no statements, merge labels (fallthrough) + if (previousCase && (previousCase.statements?.length ?? 0) === 0) { + previousCase.labels.push(...caseGroup.labels); + } + + // Compile case statements + caseGroup.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + + previousCase = caseGroup; + }); + + // **Process default case** + defaultLabel.offset = cg.code.length; + if (hasDefault) { + const defaultCase = cases.find((caseGroup) => + caseGroup.labels.some((label) => label.kind === "DefaultLabel") + ); + if (defaultCase) { + defaultCase.statements = defaultCase.statements || []; + defaultCase.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + } + } + + cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset; + + for (let i = 1; i < caseLabelIndex.length; i++) { + cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset + } + + endLabel.offset = cg.code.length; + + } else if (resultType === "Ljava/lang/String;") { + // **String Switch Handling** + const hashCaseMap: Map = new Map(); + + // Compute and store hashCode() + cg.code.push( + OPCODE.INVOKEVIRTUAL, + 0, + cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "hashCode", "()I") + ); + + // Create lookup table for hashCodes + cases.forEach((caseGroup, index) => { + caseGroup.labels.forEach((label) => { + if (label.kind === "CaseLabel") { + const caseValue = (label.expression as Literal).literalType.value; + const hashCodeValue = hashCode(caseValue.slice(1, caseValue.length - 1)); + if (!hashCaseMap.has(hashCodeValue)) { + hashCaseMap.set(hashCodeValue, caseLabels[index]); + } + } else if (label.kind === "DefaultLabel") { + caseLabels[index] = defaultLabel; + } + }); + }); + + const caseLabelIndex: number[] = [] + let indexTracker = cg.code.length; + const positionOffset = cg.code.length; + + // **LOOKUPSWITCH Implementation** + cg.code.push(OPCODE.LOOKUPSWITCH); + indexTracker++ + + // Ensure 4-byte alignment + while (cg.code.length % 4 !== 0) { + cg.code.push(0); + indexTracker++ + } + + // Default jump target + cg.code.push(0, 0, 0, defaultLabel.offset); + caseLabelIndex.push(indexTracker + 3); + indexTracker += 4; + + + // Number of case-value pairs + cg.code.push((hashCaseMap.size >> 24) & 0xff, (hashCaseMap.size >> 16) & 0xff, + (hashCaseMap.size >> 8) & 0xff, hashCaseMap.size & 0xff); + indexTracker += 4; + + // Populate LOOKUPSWITCH + hashCaseMap.forEach((label, hashCode) => { + cg.code.push(hashCode >> 24, (hashCode >> 16) & 0xff, (hashCode >> 8) & 0xff, hashCode & 0xff); + cg.code.push(0, 0, 0, label.offset); + caseLabelIndex.push(indexTracker + 7); + indexTracker += 8; + }); + + // **Case Handling** + let previousCase: SwitchCase | null = null; + + cases.filter((caseGroup) => + caseGroup.labels.some((label) => label.kind === "CaseLabel")) + .forEach((caseGroup, index) => { + caseLabels[index].offset = cg.code.length; + + // Ensure statements exist + caseGroup.statements = caseGroup.statements || []; + + // Handle fallthrough + if (previousCase && (previousCase.statements?.length ?? 0) === 0) { + previousCase.labels.push(...caseGroup.labels); + } + + // **String Comparison for Collisions** + const caseValue = caseGroup.labels.find((label): label is CaseLabel => label.kind === "CaseLabel"); + if (caseValue) { + // TODO: check for actual String equality instead of just rely on hashCode equality + // (see the commented out code below) + + // const caseStr = (caseValue.expression as Literal).literalType.value; + // const caseStrIndex = cg.constantPoolManager.indexStringInfo(caseStr); + + // cg.code.push(OPCODE.LDC, caseStrIndex); + // cg.code.push( + // OPCODE.INVOKEVIRTUAL, + // 0, + // cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "equals", "(Ljava/lang/Object;)Z") + // ); + // + const caseEndLabel = cg.generateNewLabel(); + // cg.addBranchInstr(OPCODE.IFEQ, caseEndLabel); + + // Compile case statements + caseGroup.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + + caseEndLabel.offset = cg.code.length; + } + + previousCase = caseGroup; + }); + + // **Default Case Handling** + defaultLabel.offset = cg.code.length; + const defaultCase = cases.find((caseGroup) => + caseGroup.labels.some((label) => label.kind === "DefaultLabel")); + + if (defaultCase) { + defaultCase.statements = defaultCase.statements || []; + defaultCase.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + } + + cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset; + + for (let i = 1; i < caseLabelIndex.length; i++) { + cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset + } + + endLabel.offset = cg.code.length; + + } else { + throw new Error(`Switch statements only support byte, short, int, char, or String types. Found: ${resultType}`); + } + + cg.switchLabels.pop(); + + return { stackSize: maxStack, resultType: EMPTY_TYPE }; } } @@ -958,6 +1579,7 @@ class CodeGenerator { stackSize: number = 0 labels: Label[] = [] loopLabels: Label[][] = [] + switchLabels: Label[] = [] code: number[] = [] constructor(symbolTable: SymbolTable, constantPoolManager: ConstantPoolManager) { diff --git a/src/compiler/grammar.pegjs b/src/compiler/grammar.pegjs index 8f5e3ca0..505f648e 100755 --- a/src/compiler/grammar.pegjs +++ b/src/compiler/grammar.pegjs @@ -774,7 +774,42 @@ AssertStatement = assert Expression (colon Expression) semicolon SwitchStatement - = TO_BE_ADDED + = switch lparen expr:Expression rparen lcurly + cases:SwitchBlock? + rcurly { + return addLocInfo({ + kind: "SwitchStatement", + expression: expr, + cases: cases ?? [], + }); + } + +SwitchBlock + = cases:SwitchBlockStatementGroup* { + return cases; + } + +SwitchBlockStatementGroup + = labels:SwitchLabel+ stmts:BlockStatement* { + return { + kind: "SwitchBlockStatementGroup", + labels: labels, + statements: stmts, + }; + } + +SwitchLabel + = case expr:Expression colon { + return { + kind: "CaseLabel", + expression: expr, + }; + } + / default colon { + return { + kind: "DefaultLabel", + }; + } DoStatement = do body:Statement while lparen expr:Expression rparen semicolon { @@ -1079,7 +1114,8 @@ MultiplicativeExpression } UnaryExpression - = PostfixExpression + = / CastExpression + / PostfixExpression / op:PrefixOp expr:UnaryExpression { return addLocInfo({ kind: "PrefixExpression", @@ -1087,7 +1123,6 @@ UnaryExpression expression: expr, }) } - / CastExpression / SwitchExpression PrefixOp @@ -1107,8 +1142,19 @@ PostfixExpression } CastExpression - = lparen PrimitiveType rparen UnaryExpression - / lparen ReferenceType rparen (LambdaExpression / !(PlusMinus) UnaryExpression) + = lparen castType:PrimitiveType rparen expr:UnaryExpression { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + })} + / lparen castType:ReferenceType rparen expr:(LambdaExpression / !(PlusMinus) UnaryExpression) { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + }) + } SwitchExpression = SwitchStatement diff --git a/src/compiler/grammar.ts b/src/compiler/grammar.ts index 4b6f1ee6..57470926 100755 --- a/src/compiler/grammar.ts +++ b/src/compiler/grammar.ts @@ -776,7 +776,42 @@ AssertStatement = assert Expression (colon Expression) semicolon SwitchStatement - = TO_BE_ADDED + = switch lparen expr:Expression rparen lcurly + cases:SwitchBlock? + rcurly { + return addLocInfo({ + kind: "SwitchStatement", + expression: expr, + cases: cases ?? [], + }); + } + +SwitchBlock + = cases:SwitchBlockStatementGroup* { + return cases; + } + +SwitchBlockStatementGroup + = labels:SwitchLabel+ stmts:BlockStatement* { + return { + kind: "SwitchBlockStatementGroup", + labels: labels, + statements: stmts, + }; + } + +SwitchLabel + = case expr:Expression colon { + return { + kind: "CaseLabel", + expression: expr, + }; + } + / default colon { + return { + kind: "DefaultLabel", + }; + } DoStatement = do body:Statement while lparen expr:Expression rparen semicolon { @@ -1081,7 +1116,8 @@ MultiplicativeExpression } UnaryExpression - = PostfixExpression + = CastExpression + / PostfixExpression / op:PrefixOp expr:UnaryExpression { return addLocInfo({ kind: "PrefixExpression", @@ -1089,7 +1125,6 @@ UnaryExpression expression: expr, }) } - / CastExpression / SwitchExpression PrefixOp @@ -1109,8 +1144,19 @@ PostfixExpression } CastExpression - = lparen PrimitiveType rparen UnaryExpression - / lparen ReferenceType rparen (LambdaExpression / !(PlusMinus) UnaryExpression) + = lparen castType:PrimitiveType rparen expr:UnaryExpression { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + })} + / lparen castType:ReferenceType rparen expr:(LambdaExpression / !(PlusMinus) UnaryExpression) { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + }) + } SwitchExpression = SwitchStatement diff --git a/src/jvm/utils/index.ts b/src/jvm/utils/index.ts index ea115a78..091dbc44 100644 --- a/src/jvm/utils/index.ts +++ b/src/jvm/utils/index.ts @@ -213,7 +213,7 @@ export function getField(ref: any, fieldName: string, type: JavaType) { } export function asDouble(value: number): number { - return value + return value; } export function asFloat(value: number): number { diff --git a/src/types/checker/index.ts b/src/types/checker/index.ts index 005286e1..c9107e3e 100644 --- a/src/types/checker/index.ts +++ b/src/types/checker/index.ts @@ -1,7 +1,7 @@ import { Array as ArrayType } from '../types/arrays' import { Integer, String, Throwable, Void } from '../types/references' import { CaseConstant, Node } from '../ast/specificationTypes' -import { Type } from '../types/type' +import { PrimitiveType, Type } from '../types/type' import { ArrayRequiredError, BadOperandTypesError, @@ -63,6 +63,33 @@ export const check = (node: Node, frame: Frame = Frame.globalFrame()): Result => return typeCheckBody(node, typeCheckingFrame) } +const isCastCompatible = (fromType: Type, toType: Type): boolean => { + // Handle primitive type compatibility + if (fromType instanceof PrimitiveType && toType instanceof PrimitiveType) { + const fromName = fromType.constructor.name; + const toName = toType.constructor.name; + + console.log(fromName, toName); + + return !(fromName === 'char' && toName !== 'int'); + } + + // Handle class type compatibility + if (fromType instanceof ClassType && toType instanceof ClassType) { + // Allow upcasts (base class to derived class) or downcasts (derived class to base class) + return fromType.canBeAssigned(toType) || toType.canBeAssigned(fromType); + } + + // Handle array type compatibility + if (fromType instanceof ArrayType && toType instanceof ArrayType) { + // Ensure the content types are compatible + return isCastCompatible(fromType.getContentType(), toType.getContentType()); + } + + // Disallow other cases by default + return false; +}; + export const typeCheckBody = (node: Node, frame: Frame = Frame.globalFrame()): Result => { switch (node.kind) { case 'ArrayAccess': { @@ -192,6 +219,55 @@ export const typeCheckBody = (node: Node, frame: Frame = Frame.globalFrame()): R case 'BreakStatement': { return OK_RESULT } + + case 'CastExpression': { + let castType: Type | TypeCheckerError; + let expressionType: Type | null = null; + let expressionResult: Result; + + if ('primitiveType' in node) { + castType = frame.getType(unannTypeToString(node.primitiveType), node.primitiveType.location); + } else { + throw new Error('Invalid CastExpression: Missing type information.'); + } + + if (castType instanceof TypeCheckerError) { + return newResult(null, [castType]); + } + + if ('unaryExpression' in node) { + expressionResult = typeCheckBody(node.unaryExpression, frame); + } else { + throw new Error('Invalid CastExpression: Missing expression.'); + } + + if (expressionResult.hasErrors) { + return expressionResult; + } + + expressionType = expressionResult.currentType; + if (!expressionType) { + throw new Error('Expression in cast should have a type.'); + } + + if ( + (castType instanceof PrimitiveType && expressionType instanceof PrimitiveType) + ) { + if (!isCastCompatible(expressionType, castType)) { + return newResult(null, [ + new IncompatibleTypesError(node.location), + ]); + } + } else { + return newResult(null, [ + new IncompatibleTypesError(node.location), + ]); + } + + // If the cast is valid, return the target type + return newResult(castType); + } + case 'ClassInstanceCreationExpression': { const classIdentifier = node.unqualifiedClassInstanceCreationExpression.classOrInterfaceTypeToInstantiate diff --git a/src/types/types/methods.ts b/src/types/types/methods.ts index 3a4adb6e..7c8c9365 100644 --- a/src/types/types/methods.ts +++ b/src/types/types/methods.ts @@ -189,6 +189,7 @@ export class Method implements Type { } public invoke(args: Arguments): Type | TypeCheckerError { + if (this.methodName === 'println') return new Void() const error = this.parameters.invoke(args) if (error instanceof TypeCheckerError) return error return this.returnType