From 70eedf6d95db7de6ef26e67ec0f97cc13cea7728 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 14 Jan 2025 14:53:30 +0800 Subject: [PATCH 01/29] Add opcodes for type conversions to code-generator.ts --- src/compiler/__tests__/index.ts | 2 + .../tests/assignmentExpression.test.ts | 55 ++++++++++++++ src/compiler/code-generator.ts | 74 +++++++++++++++++-- 3 files changed, 126 insertions(+), 5 deletions(-) create mode 100644 src/compiler/__tests__/tests/assignmentExpression.test.ts diff --git a/src/compiler/__tests__/index.ts b/src/compiler/__tests__/index.ts index 8f31ef5c..b5b6e742 100644 --- a/src/compiler/__tests__/index.ts +++ b/src/compiler/__tests__/index.ts @@ -9,6 +9,7 @@ import { methodInvocationTest } from "./tests/methodInvocation.test"; import { importTest } from "./tests/import.test"; import { arrayTest } from "./tests/array.test"; import { classTest } from "./tests/class.test"; +import { assignmentExpressionTest } from './tests/assignmentExpression.test' describe("compiler tests", () => { printlnTest(); @@ -22,4 +23,5 @@ describe("compiler tests", () => { importTest(); arrayTest(); classTest(); + assignmentExpressionTest(); }) \ No newline at end of file diff --git a/src/compiler/__tests__/tests/assignmentExpression.test.ts b/src/compiler/__tests__/tests/assignmentExpression.test.ts new file mode 100644 index 00000000..688a5d17 --- /dev/null +++ b/src/compiler/__tests__/tests/assignmentExpression.test.ts @@ -0,0 +1,55 @@ +import { + runTest, + testCase, +} from "../__utils__/test-utils"; + +const testCases: testCase[] = [ + { + comment: "int to double assignment", + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + double y = x; + System.out.println(y); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "int to double conversion", + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + double y; + y = x; + System.out.println(y); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "int to double conversion, array", + program: ` + public class Main { + public static void main(String[] args) { + int x = 6; + double[] y = {1.0, 2.0, 3.0, 4.0, 5.0}; + y[1] = x; + System.out.println(y[1]); + } + } + `, + expectedLines: ["6.0"], + }, +]; + +export const assignmentExpressionTest = () => describe("assignment expression", () => { + for (let testCase of testCases) { + const { comment: comment, program: program, expectedLines: expectedLines } = testCase; + it(comment, () => runTest(program, expectedLines)); + } +}); diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 1b0d8c86..2a75f66f 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -164,12 +164,64 @@ const normalStoreOp: { [type: string]: OPCODE } = { Z: OPCODE.ISTORE } +// const typeConversions: { [key: string]: OPCODE } = { +// 'I->F': OPCODE.I2F, +// 'I->D': OPCODE.I2D, +// 'I->L': OPCODE.I2L, +// 'I->B': OPCODE.I2B, +// 'I->C': OPCODE.I2C, +// 'I->S': OPCODE.I2S, +// 'F->D': OPCODE.F2D, +// 'F->I': OPCODE.F2I, +// 'F->L': OPCODE.F2L, +// 'D->F': OPCODE.D2F, +// 'D->I': OPCODE.D2I, +// 'D->L': OPCODE.D2L, +// 'L->I': OPCODE.L2I, +// 'L->F': OPCODE.L2F, +// 'L->D': OPCODE.L2D +// }; + +const typeConversionsImplicit: { [key: string]: OPCODE } = { + 'I->F': OPCODE.I2F, + 'I->D': OPCODE.I2D, + 'I->L': OPCODE.I2L, + 'F->D': OPCODE.F2D, + 'L->F': OPCODE.L2F, + 'L->D': OPCODE.L2D +}; + type CompileResult = { stackSize: number resultType: string } const EMPTY_TYPE: string = '' +function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { + if (fromType === toType) { + return; + } + const conversionKey = `${fromType}->${toType}`; + if (conversionKey in typeConversionsImplicit) { + cg.code.push(typeConversionsImplicit[conversionKey]); + } else { + throw new Error(`Unsupported implicit type conversion: ${conversionKey}`); + } +} + +// function handleExplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { +// if (fromType === toType) { +// return; +// } +// const conversionKey = `${fromType}->${toType}`; +// if (conversionKey in typeConversions) { +// cg.code.push(typeConversions[conversionKey]); +// } else { +// throw new Error(`Unsupported explicit type conversion: ${conversionKey}`); +// } +// } + + const isNullLiteral = (node: Node) => { return node.kind === 'Literal' && node.literalType.kind === 'NullLiteral' } @@ -245,13 +297,16 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi vi.forEach((val, i) => { cg.code.push(OPCODE.DUP) const size1 = compile(createIntLiteralNode(i), cg).stackSize - const size2 = compile(val as Expression, cg).stackSize + const { stackSize: size2, resultType } = compile(val as Expression, cg); + handleImplicitTypeConversion(resultType, arrayElemType, cg); cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) maxStack = Math.max(maxStack, 2 + size1 + size2) }) cg.code.push(OPCODE.ASTORE, curIdx) } else { - maxStack = Math.max(maxStack, compile(vi, cg).stackSize) + const { stackSize: initializerStackSize, resultType: initializerType } = compile(vi, cg); + handleImplicitTypeConversion(initializerType, variableInfo.typeDescriptor, cg); + maxStack = Math.max(maxStack, initializerStackSize); cg.code.push( variableInfo.typeDescriptor in normalStoreOp ? normalStoreOp[variableInfo.typeDescriptor] @@ -662,15 +717,20 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi if (lhs.kind === 'ArrayAccess') { const { stackSize: size1, resultType: arrayType } = compile(lhs.primary, cg) const size2 = compile(lhs.expression, cg).stackSize - maxStack = size1 + size2 + compile(right, cg).stackSize + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg); + maxStack = size1 + size2 + rhsSize + const arrayElemType = arrayType.slice(1) + handleImplicitTypeConversion(rhsType, arrayElemType, cg); cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) } else if ( lhs.kind === 'ExpressionName' && !Array.isArray(cg.symbolTable.queryVariable(lhs.name)) ) { const info = cg.symbolTable.queryVariable(lhs.name) as VariableInfo - maxStack = 1 + compile(right, cg).stackSize + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg); + handleImplicitTypeConversion(rhsType, info.typeDescriptor, cg); + maxStack = 1 + rhsSize cg.code.push( info.typeDescriptor in normalStoreOp ? normalStoreOp[info.typeDescriptor] : OPCODE.ASTORE, info.index @@ -693,7 +753,11 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.code.push(OPCODE.ALOAD, 0) maxStack += 1 } - maxStack += compile(right, cg).stackSize + + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg); + handleImplicitTypeConversion(rhsType, fieldInfo.typeDescriptor, cg); + + maxStack += rhsSize cg.code.push( fieldInfo.accessFlags & FIELD_FLAGS.ACC_STATIC ? OPCODE.PUTSTATIC : OPCODE.PUTFIELD, 0, From 64e245d3f9897a7fedb53b79b4543940c72a28a7 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 14 Jan 2025 18:27:01 +0800 Subject: [PATCH 02/29] Add implicit type conversions in MethodInvocation to code-generator.ts --- src/compiler/code-generator.ts | 52 ++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 2a75f66f..601aa958 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -197,10 +197,25 @@ type CompileResult = { } const EMPTY_TYPE: string = '' +function areClassTypesCompatible(fromType: string, toType: string): boolean { + const cleanFrom = fromType.replace(/^L|;$/g, ''); + const cleanTo = toType.replace(/^L|;$/g, ''); + return cleanFrom === cleanTo; +} + function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { - if (fromType === toType) { + console.debug(`Converting from: ${fromType}, to: ${toType}`); + if (fromType === toType || toType.replace(/^L|;$/g, '') === 'java/lang/String') { return; } + + if (fromType.startsWith('L') || toType.startsWith('L')) { + if (areClassTypesCompatible(fromType, toType) || fromType === "") { + return; + } + throw new Error(`Unsupported class type conversion: ${fromType} -> ${toType}`); + } + const conversionKey = `${fromType}->${toType}`; if (conversionKey in typeConversionsImplicit) { cg.code.push(typeConversionsImplicit[conversionKey]); @@ -627,6 +642,10 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi let resultType = EMPTY_TYPE const symbolInfos = cg.symbolTable.queryMethod(n.identifier) + if (!symbolInfos || symbolInfos.length === 0) { + throw new Error(`Method not found: ${n.identifier}`); + } + for (let i = 0; i < symbolInfos.length - 1; i++) { if (i === 0) { const varInfo = symbolInfos[i] as VariableInfo @@ -649,10 +668,31 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } const argTypes: Array = [] + + const methodInfo = symbolInfos[symbolInfos.length - 1] as MethodInfos; + if (!methodInfo || methodInfo.length === 0) { + throw new Error(`No method information found for ${n.identifier}`); + } + + const fullDescriptor = methodInfo[0].typeDescriptor; // Full descriptor, e.g., "(Ljava/lang/String;C)V" + const paramDescriptor = fullDescriptor.slice(1, fullDescriptor.indexOf(')')); // Extract "Ljava/lang/String;C" + const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g); + + // Parse individual parameter types + if (params && params.length !== n.argumentList.length) { + throw new Error( + `Parameter mismatch: expected ${params?.length || 0}, got ${n.argumentList.length}` + ); + } + n.argumentList.forEach((x, i) => { - const argCompileResult = compile(x, cg) - maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize) - argTypes.push(argCompileResult.resultType) + const argCompileResult = compile(x, cg); + maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize); + + const expectedType = params?.[i]; // Expected parameter type + handleImplicitTypeConversion(argCompileResult.resultType, expectedType ?? '', cg); + + argTypes.push(argCompileResult.resultType); }) const argDescriptor = '(' + argTypes.join('') + ')' @@ -687,7 +727,9 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } if (!foundMethod) { - throw new InvalidMethodCallError(n.identifier) + throw new InvalidMethodCallError( + `No method matching signature ${n.identifier}${argDescriptor} found.` + ); } return { stackSize: maxStack, resultType: resultType } }, From a33c9aff758bd423b6a7e06177bbaa69be93df9e Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 14 Jan 2025 18:28:12 +0800 Subject: [PATCH 03/29] Add implicit type conversions in MethodInvocation to code-generator.ts --- src/compiler/code-generator.ts | 73 +++++++++++++++++----------------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 601aa958..3d9e74b3 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -189,7 +189,7 @@ const typeConversionsImplicit: { [key: string]: OPCODE } = { 'F->D': OPCODE.F2D, 'L->F': OPCODE.L2F, 'L->D': OPCODE.L2D -}; +} type CompileResult = { stackSize: number @@ -198,29 +198,29 @@ type CompileResult = { const EMPTY_TYPE: string = '' function areClassTypesCompatible(fromType: string, toType: string): boolean { - const cleanFrom = fromType.replace(/^L|;$/g, ''); - const cleanTo = toType.replace(/^L|;$/g, ''); - return cleanFrom === cleanTo; + const cleanFrom = fromType.replace(/^L|;$/g, '') + const cleanTo = toType.replace(/^L|;$/g, '') + return cleanFrom === cleanTo } function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { - console.debug(`Converting from: ${fromType}, to: ${toType}`); + console.debug(`Converting from: ${fromType}, to: ${toType}`) if (fromType === toType || toType.replace(/^L|;$/g, '') === 'java/lang/String') { - return; + return } if (fromType.startsWith('L') || toType.startsWith('L')) { - if (areClassTypesCompatible(fromType, toType) || fromType === "") { - return; + if (areClassTypesCompatible(fromType, toType) || fromType === '') { + return } - throw new Error(`Unsupported class type conversion: ${fromType} -> ${toType}`); + throw new Error(`Unsupported class type conversion: ${fromType} -> ${toType}`) } - const conversionKey = `${fromType}->${toType}`; + const conversionKey = `${fromType}->${toType}` if (conversionKey in typeConversionsImplicit) { - cg.code.push(typeConversionsImplicit[conversionKey]); + cg.code.push(typeConversionsImplicit[conversionKey]) } else { - throw new Error(`Unsupported implicit type conversion: ${conversionKey}`); + throw new Error(`Unsupported implicit type conversion: ${conversionKey}`) } } @@ -236,7 +236,6 @@ function handleImplicitTypeConversion(fromType: string, toType: string, cg: Code // } // } - const isNullLiteral = (node: Node) => { return node.kind === 'Literal' && node.literalType.kind === 'NullLiteral' } @@ -312,16 +311,16 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi vi.forEach((val, i) => { cg.code.push(OPCODE.DUP) const size1 = compile(createIntLiteralNode(i), cg).stackSize - const { stackSize: size2, resultType } = compile(val as Expression, cg); - handleImplicitTypeConversion(resultType, arrayElemType, cg); + const { stackSize: size2, resultType } = compile(val as Expression, cg) + handleImplicitTypeConversion(resultType, arrayElemType, cg) cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) maxStack = Math.max(maxStack, 2 + size1 + size2) }) cg.code.push(OPCODE.ASTORE, curIdx) } else { - const { stackSize: initializerStackSize, resultType: initializerType } = compile(vi, cg); - handleImplicitTypeConversion(initializerType, variableInfo.typeDescriptor, cg); - maxStack = Math.max(maxStack, initializerStackSize); + const { stackSize: initializerStackSize, resultType: initializerType } = compile(vi, cg) + handleImplicitTypeConversion(initializerType, variableInfo.typeDescriptor, cg) + maxStack = Math.max(maxStack, initializerStackSize) cg.code.push( variableInfo.typeDescriptor in normalStoreOp ? normalStoreOp[variableInfo.typeDescriptor] @@ -643,7 +642,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const symbolInfos = cg.symbolTable.queryMethod(n.identifier) if (!symbolInfos || symbolInfos.length === 0) { - throw new Error(`Method not found: ${n.identifier}`); + throw new Error(`Method not found: ${n.identifier}`) } for (let i = 0; i < symbolInfos.length - 1; i++) { @@ -669,30 +668,30 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const argTypes: Array = [] - const methodInfo = symbolInfos[symbolInfos.length - 1] as MethodInfos; + const methodInfo = symbolInfos[symbolInfos.length - 1] as MethodInfos if (!methodInfo || methodInfo.length === 0) { - throw new Error(`No method information found for ${n.identifier}`); + throw new Error(`No method information found for ${n.identifier}`) } - const fullDescriptor = methodInfo[0].typeDescriptor; // Full descriptor, e.g., "(Ljava/lang/String;C)V" - const paramDescriptor = fullDescriptor.slice(1, fullDescriptor.indexOf(')')); // Extract "Ljava/lang/String;C" - const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g); + const fullDescriptor = methodInfo[0].typeDescriptor // Full descriptor, e.g., "(Ljava/lang/String;C)V" + const paramDescriptor = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) // Extract "Ljava/lang/String;C" + const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) // Parse individual parameter types if (params && params.length !== n.argumentList.length) { throw new Error( `Parameter mismatch: expected ${params?.length || 0}, got ${n.argumentList.length}` - ); + ) } n.argumentList.forEach((x, i) => { - const argCompileResult = compile(x, cg); - maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize); + const argCompileResult = compile(x, cg) + maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize) - const expectedType = params?.[i]; // Expected parameter type - handleImplicitTypeConversion(argCompileResult.resultType, expectedType ?? '', cg); + const expectedType = params?.[i] // Expected parameter type + handleImplicitTypeConversion(argCompileResult.resultType, expectedType ?? '', cg) - argTypes.push(argCompileResult.resultType); + argTypes.push(argCompileResult.resultType) }) const argDescriptor = '(' + argTypes.join('') + ')' @@ -729,7 +728,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi if (!foundMethod) { throw new InvalidMethodCallError( `No method matching signature ${n.identifier}${argDescriptor} found.` - ); + ) } return { stackSize: maxStack, resultType: resultType } }, @@ -759,19 +758,19 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi if (lhs.kind === 'ArrayAccess') { const { stackSize: size1, resultType: arrayType } = compile(lhs.primary, cg) const size2 = compile(lhs.expression, cg).stackSize - const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg); + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) maxStack = size1 + size2 + rhsSize const arrayElemType = arrayType.slice(1) - handleImplicitTypeConversion(rhsType, arrayElemType, cg); + handleImplicitTypeConversion(rhsType, arrayElemType, cg) cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) } else if ( lhs.kind === 'ExpressionName' && !Array.isArray(cg.symbolTable.queryVariable(lhs.name)) ) { const info = cg.symbolTable.queryVariable(lhs.name) as VariableInfo - const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg); - handleImplicitTypeConversion(rhsType, info.typeDescriptor, cg); + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) + handleImplicitTypeConversion(rhsType, info.typeDescriptor, cg) maxStack = 1 + rhsSize cg.code.push( info.typeDescriptor in normalStoreOp ? normalStoreOp[info.typeDescriptor] : OPCODE.ASTORE, @@ -796,8 +795,8 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi maxStack += 1 } - const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg); - handleImplicitTypeConversion(rhsType, fieldInfo.typeDescriptor, cg); + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) + handleImplicitTypeConversion(rhsType, fieldInfo.typeDescriptor, cg) maxStack += rhsSize cg.code.push( From 1848b4df56f67140c12f211c457662683bc01368 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 14 Jan 2025 23:00:08 +0800 Subject: [PATCH 04/29] Handle stack size during primitive type conversions in code-generator.ts --- src/compiler/code-generator.ts | 37 ++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 3d9e74b3..3d4e5a2c 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -203,15 +203,15 @@ function areClassTypesCompatible(fromType: string, toType: string): boolean { return cleanFrom === cleanTo } -function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { +function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator): number { console.debug(`Converting from: ${fromType}, to: ${toType}`) if (fromType === toType || toType.replace(/^L|;$/g, '') === 'java/lang/String') { - return + return 0; } if (fromType.startsWith('L') || toType.startsWith('L')) { if (areClassTypesCompatible(fromType, toType) || fromType === '') { - return + return 0; } throw new Error(`Unsupported class type conversion: ${fromType} -> ${toType}`) } @@ -219,6 +219,13 @@ function handleImplicitTypeConversion(fromType: string, toType: string, cg: Code const conversionKey = `${fromType}->${toType}` if (conversionKey in typeConversionsImplicit) { cg.code.push(typeConversionsImplicit[conversionKey]) + if (!(fromType in ['L', 'D']) && toType in ['L', 'D']) { + return 1; + } else if (!(toType in ['L', 'D']) && fromType in ['L', 'D']) { + return -1; + } else { + return 0; + } } else { throw new Error(`Unsupported implicit type conversion: ${conversionKey}`) } @@ -312,15 +319,15 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.code.push(OPCODE.DUP) const size1 = compile(createIntLiteralNode(i), cg).stackSize const { stackSize: size2, resultType } = compile(val as Expression, cg) - handleImplicitTypeConversion(resultType, arrayElemType, cg) + const stackSizeChange = handleImplicitTypeConversion(resultType, arrayElemType, cg) cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) - maxStack = Math.max(maxStack, 2 + size1 + size2) + maxStack = Math.max(maxStack, 2 + size1 + size2 + stackSizeChange) }) cg.code.push(OPCODE.ASTORE, curIdx) } else { const { stackSize: initializerStackSize, resultType: initializerType } = compile(vi, cg) - handleImplicitTypeConversion(initializerType, variableInfo.typeDescriptor, cg) - maxStack = Math.max(maxStack, initializerStackSize) + const stackSizeChange = handleImplicitTypeConversion(initializerType, variableInfo.typeDescriptor, cg) + maxStack = Math.max(maxStack, initializerStackSize + stackSizeChange) cg.code.push( variableInfo.typeDescriptor in normalStoreOp ? normalStoreOp[variableInfo.typeDescriptor] @@ -686,10 +693,10 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi n.argumentList.forEach((x, i) => { const argCompileResult = compile(x, cg) - maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize) const expectedType = params?.[i] // Expected parameter type - handleImplicitTypeConversion(argCompileResult.resultType, expectedType ?? '', cg) + const stackSizeChange = handleImplicitTypeConversion(argCompileResult.resultType, expectedType ?? '', cg) + maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize + stackSizeChange) argTypes.push(argCompileResult.resultType) }) @@ -759,10 +766,10 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const { stackSize: size1, resultType: arrayType } = compile(lhs.primary, cg) const size2 = compile(lhs.expression, cg).stackSize const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) - maxStack = size1 + size2 + rhsSize const arrayElemType = arrayType.slice(1) - handleImplicitTypeConversion(rhsType, arrayElemType, cg) + const stackSizeChange = handleImplicitTypeConversion(rhsType, arrayElemType, cg) + maxStack = Math.max(maxStack, size1 + size2 + rhsSize + stackSizeChange) cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) } else if ( lhs.kind === 'ExpressionName' && @@ -770,8 +777,8 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi ) { const info = cg.symbolTable.queryVariable(lhs.name) as VariableInfo const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) - handleImplicitTypeConversion(rhsType, info.typeDescriptor, cg) - maxStack = 1 + rhsSize + const stackSizeChange = handleImplicitTypeConversion(rhsType, info.typeDescriptor, cg) + maxStack = Math.max(maxStack, 1 + rhsSize + stackSizeChange) cg.code.push( info.typeDescriptor in normalStoreOp ? normalStoreOp[info.typeDescriptor] : OPCODE.ASTORE, info.index @@ -796,9 +803,9 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) - handleImplicitTypeConversion(rhsType, fieldInfo.typeDescriptor, cg) + const stackSizeChange = handleImplicitTypeConversion(rhsType, fieldInfo.typeDescriptor, cg) - maxStack += rhsSize + maxStack = Math.max(maxStack, maxStack + rhsSize + stackSizeChange) cg.code.push( fieldInfo.accessFlags & FIELD_FLAGS.ACC_STATIC ? OPCODE.PUTSTATIC : OPCODE.PUTFIELD, 0, From f88069edf5844a90c6b45d264757ff2620f0a6b6 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 21 Jan 2025 11:10:24 +0800 Subject: [PATCH 05/29] Handle type conversions in binary expressions in code-generator.ts --- src/compiler/__tests__/tests/println.test.ts | 16 +++ src/compiler/code-generator.ts | 126 ++++++++++++++++--- src/jvm/utils/index.ts | 2 +- 3 files changed, 126 insertions(+), 18 deletions(-) diff --git a/src/compiler/__tests__/tests/println.test.ts b/src/compiler/__tests__/tests/println.test.ts index 4e2e4635..9b5ebbea 100644 --- a/src/compiler/__tests__/tests/println.test.ts +++ b/src/compiler/__tests__/tests/println.test.ts @@ -98,6 +98,22 @@ const testCases: testCase[] = [ `, expectedLines: ["true", "false"], }, + { + comment: "println with concatenated arguments", + program: ` + public class Main { + public static void main(String[] args) { + System.out.println("Hello" + " " + "world!"); + System.out.println("This is an int: " + 123); + System.out.println("This is a float: " + 4.5f); + System.out.println("This is a long: " + 10000000000L); + System.out.println("This is a double: " + 10.3); + } + } + `, + expectedLines: ["Hello world!", "This is an int: 123", "This is a float: 4.5", + "This is a long: 10000000000", "This is a double: 10.3"], + }, { comment: "multiple println statements", program: ` diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 3d4e5a2c..36878210 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -243,6 +243,36 @@ function handleImplicitTypeConversion(fromType: string, toType: string, cg: Code // } // } +function generateStringConversion(valueType: string, cg: CodeGenerator): void { + const stringClass = 'java/lang/String'; + + // Map primitive types to `String.valueOf()` method descriptors + const valueOfDescriptors: { [key: string]: string } = { + I: '(I)Ljava/lang/String;', // int + J: '(J)Ljava/lang/String;', // long + F: '(F)Ljava/lang/String;', // float + D: '(D)Ljava/lang/String;', // double + Z: '(Z)Ljava/lang/String;', // boolean + B: '(B)Ljava/lang/String;', // byte + S: '(S)Ljava/lang/String;', // short + C: '(C)Ljava/lang/String;' // char + }; + + const descriptor = valueOfDescriptors[valueType]; + if (!descriptor) { + throw new Error(`Unsupported primitive type for String conversion: ${valueType}`); + } + + const methodIndex = cg.constantPoolManager.indexMethodrefInfo( + stringClass, + 'valueOf', + descriptor + ); + + cg.code.push(OPCODE.INVOKESTATIC, 0, methodIndex); +} + + const isNullLiteral = (node: Node) => { return node.kind === 'Literal' && node.literalType.kind === 'NullLiteral' } @@ -849,33 +879,95 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } } - const { stackSize: size1, resultType: type } = compile(left, cg) - const { stackSize: size2 } = compile(right, cg) + const { stackSize: size1, resultType: leftType } = compile(left, cg) + const { stackSize: size2, resultType: rightType } = compile(right, cg) + + if (op === '+' && + (leftType === 'Ljava/lang/String;' + || rightType === 'Ljava/lang/String;')) { + console.debug(`String concatenation detected: ${leftType} ${op} ${rightType}`) + + if (leftType !== 'Ljava/lang/String;') { + generateStringConversion(leftType, cg); + } + + if (rightType !== 'Ljava/lang/String;') { + generateStringConversion(rightType, cg); + } + + // Invoke `String.concat` for concatenation + const concatMethodIndex = cg.constantPoolManager.indexMethodrefInfo( + 'java/lang/String', + 'concat', + '(Ljava/lang/String;)Ljava/lang/String;' + ); + cg.code.push(OPCODE.INVOKEVIRTUAL, 0, concatMethodIndex); + + return { + stackSize: Math.max(size1, size2 + 1), // Max stack size plus one for the concatenation + resultType: 'Ljava/lang/String;' + }; + } - switch (type) { + if (leftType !== rightType) { + console.debug( + `Type mismatch detected: leftType=${leftType}, rightType=${rightType}. Applying implicit conversions.` + ); + + if (['D', 'F'].includes(leftType) || ['D', 'F'].includes(rightType)) { + // Promote both to double if one is double, or to float otherwise + if (leftType !== 'D' && rightType === 'D') { + handleImplicitTypeConversion(leftType, 'D', cg); + } else if (leftType === 'D' && rightType !== 'D') { + handleImplicitTypeConversion(rightType, 'D', cg); + } else if (leftType !== 'F' && rightType === 'F') { + handleImplicitTypeConversion(leftType, 'F', cg); + } else if (leftType === 'F' && rightType !== 'F') { + handleImplicitTypeConversion(rightType, 'F', cg); + } + } else if (['J'].includes(leftType) || ['J'].includes(rightType)) { + // Promote both to long if one is long + if (leftType !== 'J' && rightType === 'J') { + handleImplicitTypeConversion(leftType, 'J', cg); + } else if (leftType === 'J' && rightType !== 'J') { + handleImplicitTypeConversion(rightType, 'J', cg); + } + } else { + // Promote both to int as the common type for smaller types like byte, short, char + if (leftType !== 'I') { + handleImplicitTypeConversion(leftType, 'I', cg); + } + if (rightType !== 'I') { + handleImplicitTypeConversion(rightType, 'I', cg); + } + } + } + + // Perform the operation + switch (leftType) { case 'B': - cg.code.push(intBinaryOp[op], OPCODE.I2B) - break + cg.code.push(intBinaryOp[op], OPCODE.I2B); + break; case 'D': - cg.code.push(doubleBinaryOp[op]) - break + cg.code.push(doubleBinaryOp[op]); + break; case 'F': - cg.code.push(floatBinaryOp[op]) - break + cg.code.push(floatBinaryOp[op]); + break; case 'I': - cg.code.push(intBinaryOp[op]) - break + cg.code.push(intBinaryOp[op]); + break; case 'J': - cg.code.push(longBinaryOp[op]) - break + cg.code.push(longBinaryOp[op]); + break; case 'S': - cg.code.push(intBinaryOp[op], OPCODE.I2S) - break + cg.code.push(intBinaryOp[op], OPCODE.I2S); + break; } return { - stackSize: Math.max(size1, 1 + (['D', 'J'].includes(type) ? 1 : 0) + size2), - resultType: type + stackSize: Math.max(size1, 1 + (['D', 'J'].includes(leftType) ? 1 : 0) + size2), + resultType: leftType } }, diff --git a/src/jvm/utils/index.ts b/src/jvm/utils/index.ts index ea115a78..112e2406 100644 --- a/src/jvm/utils/index.ts +++ b/src/jvm/utils/index.ts @@ -213,7 +213,7 @@ export function getField(ref: any, fieldName: string, type: JavaType) { } export function asDouble(value: number): number { - return value + return Number(value) } export function asFloat(value: number): number { From db877207f67641405dcb0ae445074d9ef035c471 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 21 Jan 2025 11:27:26 +0800 Subject: [PATCH 06/29] Fix bugs with float conversions --- src/compiler/code-generator.ts | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 36878210..16b45704 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -177,9 +177,9 @@ const normalStoreOp: { [type: string]: OPCODE } = { // 'D->F': OPCODE.D2F, // 'D->I': OPCODE.D2I, // 'D->L': OPCODE.D2L, -// 'L->I': OPCODE.L2I, -// 'L->F': OPCODE.L2F, -// 'L->D': OPCODE.L2D +// 'J->I': OPCODE.L2I, +// 'J->F': OPCODE.L2F, +// 'J->D': OPCODE.L2D // }; const typeConversionsImplicit: { [key: string]: OPCODE } = { @@ -187,8 +187,8 @@ const typeConversionsImplicit: { [key: string]: OPCODE } = { 'I->D': OPCODE.I2D, 'I->L': OPCODE.I2L, 'F->D': OPCODE.F2D, - 'L->F': OPCODE.L2F, - 'L->D': OPCODE.L2D + 'J->F': OPCODE.L2F, + 'J->D': OPCODE.L2D } type CompileResult = { @@ -219,9 +219,9 @@ function handleImplicitTypeConversion(fromType: string, toType: string, cg: Code const conversionKey = `${fromType}->${toType}` if (conversionKey in typeConversionsImplicit) { cg.code.push(typeConversionsImplicit[conversionKey]) - if (!(fromType in ['L', 'D']) && toType in ['L', 'D']) { + if (!(fromType in ['J', 'D']) && toType in ['J', 'D']) { return 1; - } else if (!(toType in ['L', 'D']) && fromType in ['L', 'D']) { + } else if (!(toType in ['J', 'D']) && fromType in ['J', 'D']) { return -1; } else { return 0; From 6e0a2d7c8e895f399d820600d4764c81ef8ca8aa Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 21 Jan 2025 13:11:52 +0800 Subject: [PATCH 07/29] Allow primitive types to be safely converted to String --- src/compiler/code-generator.ts | 10 +++++----- src/jvm/utils/index.ts | 2 +- src/types/types/references.ts | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 16b45704..f91cc424 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -167,16 +167,16 @@ const normalStoreOp: { [type: string]: OPCODE } = { // const typeConversions: { [key: string]: OPCODE } = { // 'I->F': OPCODE.I2F, // 'I->D': OPCODE.I2D, -// 'I->L': OPCODE.I2L, +// 'I->J': OPCODE.I2L, // 'I->B': OPCODE.I2B, // 'I->C': OPCODE.I2C, // 'I->S': OPCODE.I2S, // 'F->D': OPCODE.F2D, // 'F->I': OPCODE.F2I, -// 'F->L': OPCODE.F2L, +// 'F->J': OPCODE.F2L, // 'D->F': OPCODE.D2F, // 'D->I': OPCODE.D2I, -// 'D->L': OPCODE.D2L, +// 'D->J': OPCODE.D2L, // 'J->I': OPCODE.L2I, // 'J->F': OPCODE.L2F, // 'J->D': OPCODE.L2D @@ -185,7 +185,7 @@ const normalStoreOp: { [type: string]: OPCODE } = { const typeConversionsImplicit: { [key: string]: OPCODE } = { 'I->F': OPCODE.I2F, 'I->D': OPCODE.I2D, - 'I->L': OPCODE.I2L, + 'I->J': OPCODE.I2L, 'F->D': OPCODE.F2D, 'J->F': OPCODE.L2F, 'J->D': OPCODE.L2D @@ -904,7 +904,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.code.push(OPCODE.INVOKEVIRTUAL, 0, concatMethodIndex); return { - stackSize: Math.max(size1, size2 + 1), // Max stack size plus one for the concatenation + stackSize: Math.max(size1 + 1, size2 + 1), // Max stack size plus one for the concatenation resultType: 'Ljava/lang/String;' }; } diff --git a/src/jvm/utils/index.ts b/src/jvm/utils/index.ts index 112e2406..091dbc44 100644 --- a/src/jvm/utils/index.ts +++ b/src/jvm/utils/index.ts @@ -213,7 +213,7 @@ export function getField(ref: any, fieldName: string, type: JavaType) { } export function asDouble(value: number): number { - return Number(value) + return value; } export function asFloat(value: number): number { diff --git a/src/types/types/references.ts b/src/types/types/references.ts index 69c75e6d..edd87ad9 100644 --- a/src/types/types/references.ts +++ b/src/types/types/references.ts @@ -106,7 +106,7 @@ export class String extends ClassType { } public canBeAssigned(type: Type): boolean { - if (type instanceof Primitives.Null) return true + if (type instanceof PrimitiveType) return true return type instanceof String } } From 7e58159a15987f11c614e544dd7e3611dfc36d97 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 21 Jan 2025 13:54:31 +0800 Subject: [PATCH 08/29] Allow primitive types other than boolean in ternary conditional --- src/compiler/code-generator.ts | 60 ++++++++++++++++++++++++++++++++++ src/types/types/primitives.ts | 2 +- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index f91cc424..a6256650 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -272,6 +272,61 @@ function generateStringConversion(valueType: string, cg: CodeGenerator): void { cg.code.push(OPCODE.INVOKESTATIC, 0, methodIndex); } +// function generateBooleanConversion(type: string, cg: CodeGenerator): number { +// let stackChange = 0; // Tracks changes to the stack size +// +// switch (type) { +// case 'I': // int +// case 'B': // byte +// case 'S': // short +// case 'C': // char +// // For integer-like types, compare with zero +// cg.code.push(OPCODE.ICONST_0); // Push 0 +// stackChange += 1; // `ICONST_0` pushes a value onto the stack +// cg.code.push(OPCODE.IF_ICMPNE); // Compare and branch +// stackChange -= 2; // `IF_ICMPNE` consumes two values from the stack +// break; +// +// case 'J': // long +// // For long, compare with zero +// cg.code.push(OPCODE.LCONST_0); // Push 0L +// stackChange += 2; // `LCONST_0` pushes two values onto the stack (long takes 2 slots) +// cg.code.push(OPCODE.LCMP); // Compare top two longs +// stackChange -= 4; // `LCMP` consumes four values (two long operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'F': // float +// // For float, compare with zero +// cg.code.push(OPCODE.FCONST_0); // Push 0.0f +// stackChange += 1; // `FCONST_0` pushes a value onto the stack +// cg.code.push(OPCODE.FCMPL); // Compare top two floats +// stackChange -= 2; // `FCMPL` consumes two values (float operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'D': // double +// // For double, compare with zero +// cg.code.push(OPCODE.DCONST_0); // Push 0.0d +// stackChange += 2; // `DCONST_0` pushes two values onto the stack (double takes 2 slots) +// cg.code.push(OPCODE.DCMPL); // Compare top two doubles +// stackChange -= 4; // `DCMPL` consumes four values (two double operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'Z': // boolean +// // Already a boolean, no conversion needed +// break; +// +// default: +// throw new Error(`Cannot convert type ${type} to boolean.`); +// } +// +// return stackChange; // Return the net change in stack size +// } const isNullLiteral = (node: Node) => { return node.kind === 'Literal' && node.literalType.kind === 'NullLiteral' @@ -535,6 +590,11 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.addBranchInstr(OPCODE.GOTO, targetLabel) } return { stackSize: 0, resultType: cg.symbolTable.generateFieldDescriptor('boolean') } + } else { + if (onTrue === (parseInt(value) !== 0)) { + cg.addBranchInstr(OPCODE.GOTO, targetLabel) + } + return { stackSize: 0, resultType: cg.symbolTable.generateFieldDescriptor('boolean') } } } diff --git a/src/types/types/primitives.ts b/src/types/types/primitives.ts index 1c9c0f6d..4832c6f0 100644 --- a/src/types/types/primitives.ts +++ b/src/types/types/primitives.ts @@ -13,7 +13,7 @@ export class Boolean extends PrimitiveType { } public canBeAssigned(type: Type): boolean { - return type instanceof Boolean + return type instanceof PrimitiveType && !(type instanceof Null) } } From 6bbafe799f6a5c4b59d3edb661db52044f52428e Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 21 Jan 2025 14:35:06 +0800 Subject: [PATCH 09/29] Fix bugs in type checker --- src/types/types/methods.ts | 1 + src/types/types/primitives.ts | 2 +- src/types/types/references.ts | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/types/types/methods.ts b/src/types/types/methods.ts index 3a4adb6e..7c8c9365 100644 --- a/src/types/types/methods.ts +++ b/src/types/types/methods.ts @@ -189,6 +189,7 @@ export class Method implements Type { } public invoke(args: Arguments): Type | TypeCheckerError { + if (this.methodName === 'println') return new Void() const error = this.parameters.invoke(args) if (error instanceof TypeCheckerError) return error return this.returnType diff --git a/src/types/types/primitives.ts b/src/types/types/primitives.ts index 4832c6f0..1c9c0f6d 100644 --- a/src/types/types/primitives.ts +++ b/src/types/types/primitives.ts @@ -13,7 +13,7 @@ export class Boolean extends PrimitiveType { } public canBeAssigned(type: Type): boolean { - return type instanceof PrimitiveType && !(type instanceof Null) + return type instanceof Boolean } } diff --git a/src/types/types/references.ts b/src/types/types/references.ts index edd87ad9..69c75e6d 100644 --- a/src/types/types/references.ts +++ b/src/types/types/references.ts @@ -106,7 +106,7 @@ export class String extends ClassType { } public canBeAssigned(type: Type): boolean { - if (type instanceof PrimitiveType) return true + if (type instanceof Primitives.Null) return true return type instanceof String } } From 57eeb8bcbb757765aadfbbadd7873a91b08410a2 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 21 Jan 2025 15:44:36 +0800 Subject: [PATCH 10/29] Enable unary expressions for non-integer types --- .../tests/arithmeticExpression.test.ts | 52 ++++++++++++++ .../tests/assignmentExpression.test.ts | 69 +++++++++++++++++++ .../__tests__/tests/unaryExpression.test.ts | 39 +++++++++++ src/compiler/code-generator.ts | 13 +++- 4 files changed, 172 insertions(+), 1 deletion(-) diff --git a/src/compiler/__tests__/tests/arithmeticExpression.test.ts b/src/compiler/__tests__/tests/arithmeticExpression.test.ts index abe02048..72a38c1d 100644 --- a/src/compiler/__tests__/tests/arithmeticExpression.test.ts +++ b/src/compiler/__tests__/tests/arithmeticExpression.test.ts @@ -78,6 +78,58 @@ const testCases: testCase[] = [ expectedLines: ["-2147483648", "-32769", "-32768", "-129", "-128", "-1", "0", "1", "127", "128", "32767", "32768", "2147483647"], }, + { + comment: "Mixed int and float addition (order swapped)", + program: ` + public class Main { + public static void main(String[] args) { + int a = 5; + float b = 2.5f; + System.out.println(a + b); + } + } + `, + expectedLines: ["7.5"], + }, + { + comment: "Mixed long and double multiplication", + program: ` + public class Main { + public static void main(String[] args) { + double a = 3.5; + long b = 10L; + System.out.println(a * b); + } + } + `, + expectedLines: ["35.0"], + }, + { + comment: "Mixed long and double multiplication (order swapped)", + program: ` + public class Main { + public static void main(String[] args) { + long a = 10L; + double b = 3.5; + System.out.println(a * b); + } + } + `, + expectedLines: ["35.0"], + }, + { + comment: "Mixed int and double division", + program: ` + public class Main { + public static void main(String[] args) { + double a = 2.0; + int b = 5; + System.out.println(a / b); + } + } + `, + expectedLines: ["0.4"], + } ]; export const arithmeticExpressionTest = () => describe("arithmetic expression", () => { diff --git a/src/compiler/__tests__/tests/assignmentExpression.test.ts b/src/compiler/__tests__/tests/assignmentExpression.test.ts index 688a5d17..5950de37 100644 --- a/src/compiler/__tests__/tests/assignmentExpression.test.ts +++ b/src/compiler/__tests__/tests/assignmentExpression.test.ts @@ -45,6 +45,75 @@ const testCases: testCase[] = [ `, expectedLines: ["6.0"], }, + { + comment: "int to long", + program: ` + public class Main { + public static void main(String[] args) { + int a = 123; + long b = a; + System.out.println(b); + } + } + `, + expectedLines: ["123"], + }, + { + comment: "int to float", + program: ` + public class Main { + public static void main(String[] args) { + int a = 123; + float b = a; + System.out.println(b); + } + } + `, + expectedLines: ["123.0"], + }, + + // long -> other types + { + comment: "long to float", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + float b = a; + System.out.println(b); + } + } + `, + expectedLines: ["9.223372E18"], + }, + { + comment: "long to double", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + double b = a; + System.out.println(b); + } + } + `, + expectedLines: ["9.223372036854776E18"], + }, + + // float -> other types + { + comment: "float to double", + program: ` + public class Main { + public static void main(String[] args) { + float a = 3.0f; + double b = a; + System.out.println(b); + } + } + `, + expectedLines: ["3.0"], + }, ]; export const assignmentExpressionTest = () => describe("assignment expression", () => { diff --git a/src/compiler/__tests__/tests/unaryExpression.test.ts b/src/compiler/__tests__/tests/unaryExpression.test.ts index 175c6a2d..0e9aa469 100644 --- a/src/compiler/__tests__/tests/unaryExpression.test.ts +++ b/src/compiler/__tests__/tests/unaryExpression.test.ts @@ -159,6 +159,45 @@ const testCases: testCase[] = [ }`, expectedLines: ["10", "10", "-10", "-10", "-10", "-10", "10", "9", "-10"], }, + { + comment: "unary plus/minus for long", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["9223372036854775807", "-9223372036854775807"], + }, + { + comment: "unary plus/minus for float", + program: ` + public class Main { + public static void main(String[] args) { + float a = 4.5f; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["4.5", "-4.5"], + }, + { + comment: "unary plus/minus for double", + program: ` + public class Main { + public static void main(String[] args) { + double a = 10.75; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["10.75", "-10.75"], + }, { comment: "bitwise complement", program: ` diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index a6256650..715d89f6 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -1063,7 +1063,18 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const compileResult = compile(expr, cg) if (op === '-') { - cg.code.push(OPCODE.INEG) + const negationOpcodes: { [type: string]: OPCODE } = { + I: OPCODE.INEG, // Integer negation + J: OPCODE.LNEG, // Long negation + F: OPCODE.FNEG, // Float negation + D: OPCODE.DNEG, // Double negation + }; + + if (compileResult.resultType in negationOpcodes) { + cg.code.push(negationOpcodes[compileResult.resultType]); + } else { + throw new Error(`Unary '-' not supported for type: ${compileResult.resultType}`); + } } else if (op === '~') { cg.code.push(OPCODE.ICONST_M1, OPCODE.IXOR) compileResult.stackSize = Math.max(compileResult.stackSize, 2) From 2a2cc0b829a803c014966bd753d4c77f93777f8c Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Sat, 25 Jan 2025 14:13:21 +0800 Subject: [PATCH 11/29] Fix bug in type conversion for binary expressions --- src/compiler/code-generator.ts | 40 ++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 715d89f6..f574d5c9 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -940,6 +940,8 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } const { stackSize: size1, resultType: leftType } = compile(left, cg) + const insertConversionIndex = cg.code.length; + cg.code.push(OPCODE.NOP); const { stackSize: size2, resultType: rightType } = compile(right, cg) if (op === '+' && @@ -969,42 +971,58 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }; } + let finalType = leftType; + if (leftType !== rightType) { console.debug( `Type mismatch detected: leftType=${leftType}, rightType=${rightType}. Applying implicit conversions.` ); + const conversionKeyLeft = `${leftType}->${rightType}` + const conversionKeyRight = `${rightType}->${leftType}` + if (['D', 'F'].includes(leftType) || ['D', 'F'].includes(rightType)) { // Promote both to double if one is double, or to float otherwise if (leftType !== 'D' && rightType === 'D') { - handleImplicitTypeConversion(leftType, 'D', cg); + cg.code.fill(typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, insertConversionIndex + 1) + finalType = 'D'; } else if (leftType === 'D' && rightType !== 'D') { - handleImplicitTypeConversion(rightType, 'D', cg); + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + finalType = 'D'; } else if (leftType !== 'F' && rightType === 'F') { - handleImplicitTypeConversion(leftType, 'F', cg); + // handleImplicitTypeConversion(leftType, 'F', cg); + cg.code.fill(typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, insertConversionIndex + 1) + finalType = 'F'; } else if (leftType === 'F' && rightType !== 'F') { - handleImplicitTypeConversion(rightType, 'F', cg); + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + finalType = 'F'; } } else if (['J'].includes(leftType) || ['J'].includes(rightType)) { // Promote both to long if one is long if (leftType !== 'J' && rightType === 'J') { - handleImplicitTypeConversion(leftType, 'J', cg); + cg.code.fill(typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, insertConversionIndex + 1) } else if (leftType === 'J' && rightType !== 'J') { - handleImplicitTypeConversion(rightType, 'J', cg); + cg.code.push(typeConversionsImplicit[conversionKeyRight]) } + finalType = 'J'; } else { // Promote both to int as the common type for smaller types like byte, short, char if (leftType !== 'I') { - handleImplicitTypeConversion(leftType, 'I', cg); + cg.code.fill(typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, insertConversionIndex + 1) } if (rightType !== 'I') { - handleImplicitTypeConversion(rightType, 'I', cg); + cg.code.push(typeConversionsImplicit[conversionKeyRight]) } + finalType = 'I'; } } // Perform the operation - switch (leftType) { + switch (finalType) { case 'B': cg.code.push(intBinaryOp[op], OPCODE.I2B); break; @@ -1026,8 +1044,8 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } return { - stackSize: Math.max(size1, 1 + (['D', 'J'].includes(leftType) ? 1 : 0) + size2), - resultType: leftType + stackSize: Math.max(size1, 1 + (['D', 'J'].includes(finalType) ? 1 : 0) + size2), + resultType: finalType } }, From 0799bf7f8e1cdff442c4ec00734e394c129bfcff Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Sat, 25 Jan 2025 22:21:44 +0800 Subject: [PATCH 12/29] Add type cast support to the compiler --- src/ast/astExtractor/expression-extractor.ts | 30 ++++ src/ast/types/blocks-and-statements.ts | 7 + .../__tests__/__utils__/test-utils.ts | 1 + src/compiler/__tests__/index.ts | 2 + .../__tests__/tests/castExpression.test.ts | 145 ++++++++++++++++++ src/compiler/code-generator.ts | 103 +++++++++---- src/compiler/grammar.pegjs | 19 ++- src/compiler/grammar.ts | 19 ++- 8 files changed, 286 insertions(+), 40 deletions(-) create mode 100644 src/compiler/__tests__/tests/castExpression.test.ts diff --git a/src/ast/astExtractor/expression-extractor.ts b/src/ast/astExtractor/expression-extractor.ts index 3301db6a..19782358 100644 --- a/src/ast/astExtractor/expression-extractor.ts +++ b/src/ast/astExtractor/expression-extractor.ts @@ -2,6 +2,7 @@ import { ArgumentListCtx, BaseJavaCstVisitorWithDefaults, BinaryExpressionCtx, + CastExpressionCtx, ClassOrInterfaceTypeToInstantiateCtx, BooleanLiteralCtx, ExpressionCstNode, @@ -86,6 +87,35 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults { } } + castExpression(ctx: CastExpressionCtx) { + // Handle primitive cast expressions + if (ctx.primitiveCastExpression && ctx.primitiveCastExpression?.length > 0) { + const primitiveCast = ctx.primitiveCastExpression[0]; + const type = this.extractType(primitiveCast.children.primitiveType[0]); + const expression = this.visit(primitiveCast.children.unaryExpression[0]); + console.debug({primitiveCast, type, expression}); + return { + kind: "CastExpression", + castType: type, + expression: expression, + location: this.location, + }; + } + + throw new Error("Invalid CastExpression format."); + } + + private extractType(typeCtx: any) { + if (typeCtx.Identifier) { + return typeCtx.Identifier[0].image; + } + if (typeCtx.unannPrimitiveType) { + return this.visit(typeCtx.unannPrimitiveType); + } + throw new Error("Invalid type context in cast expression."); + } + + private makeBinaryExpression( operators: IToken[], operands: UnaryExpressionCstNode[] diff --git a/src/ast/types/blocks-and-statements.ts b/src/ast/types/blocks-and-statements.ts index fe5dc7ad..fa2012c4 100644 --- a/src/ast/types/blocks-and-statements.ts +++ b/src/ast/types/blocks-and-statements.ts @@ -119,6 +119,7 @@ export type Expression = | BinaryExpression | UnaryExpression | TernaryExpression + | CastExpression | Void; export interface Void extends BaseNode { @@ -289,3 +290,9 @@ export interface TernaryExpression extends BaseNode { consequent: Expression; alternate: Expression; } + +export interface CastExpression extends BaseNode { + kind: "CastExpression"; + type: UnannType; + expression: Expression; +} diff --git a/src/compiler/__tests__/__utils__/test-utils.ts b/src/compiler/__tests__/__utils__/test-utils.ts index 36d11f57..042ca8ef 100644 --- a/src/compiler/__tests__/__utils__/test-utils.ts +++ b/src/compiler/__tests__/__utils__/test-utils.ts @@ -24,6 +24,7 @@ const binaryWriter = new BinaryWriter(); export function runTest(program: string, expectedLines: string[]) { const ast = parser.parse(program); + console.log(JSON.stringify(ast, null, 2) + "\n"); expect(ast).not.toBeNull(); if (debug) { diff --git a/src/compiler/__tests__/index.ts b/src/compiler/__tests__/index.ts index b5b6e742..e8bf633b 100644 --- a/src/compiler/__tests__/index.ts +++ b/src/compiler/__tests__/index.ts @@ -10,8 +10,10 @@ import { importTest } from "./tests/import.test"; import { arrayTest } from "./tests/array.test"; import { classTest } from "./tests/class.test"; import { assignmentExpressionTest } from './tests/assignmentExpression.test' +import { castExpressionTest } from './tests/castExpression.test' describe("compiler tests", () => { + castExpressionTest(); printlnTest(); variableDeclarationTest(); arithmeticExpressionTest(); diff --git a/src/compiler/__tests__/tests/castExpression.test.ts b/src/compiler/__tests__/tests/castExpression.test.ts new file mode 100644 index 00000000..e811ec66 --- /dev/null +++ b/src/compiler/__tests__/tests/castExpression.test.ts @@ -0,0 +1,145 @@ +import { + runTest, + testCase, +} from "../__utils__/test-utils"; + +const testCases: testCase[] = [ + { + comment: "Simple primitive casting: int to float", + program: ` + public class Main { + public static void main(String[] args) { + int a = 5; + float b = (float) a; + System.out.println(b); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "Simple primitive casting: float to int", + program: ` + public class Main { + public static void main(String[] args) { + float a = 2.9f; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["2"], + }, + { + comment: "Primitive casting: double to long", + program: ` + public class Main { + public static void main(String[] args) { + double a = 123456789.987; + long b = (long) a; + System.out.println(b); + } + } + `, + expectedLines: ["123456789"], + }, + { + comment: "Primitive casting: long to byte", + program: ` + public class Main { + public static void main(String[] args) { + long a = 257; + byte b = (byte) a; + System.out.println(b); + } + } + `, + expectedLines: ["1"], // byte wraps around at 256 + }, + { + comment: "Primitive casting: char to int", + program: ` + public class Main { + public static void main(String[] args) { + char a = 'A'; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["65"], + }, + { + comment: "Primitive casting: int to char", + program: ` + public class Main { + public static void main(String[] args) { + int a = 65; + char b = (char) a; + System.out.println(b); + } + } + `, + expectedLines: ["A"], + }, + { + comment: "Primitive casting: int to char", + program: ` + public class Main { + public static void main(String[] args) { + int a = 66; + char b = (char) a; + System.out.println(b); + } + } + `, + expectedLines: ["B"], + }, + { + comment: "Primitive casting with loss of precision", + program: ` + public class Main { + public static void main(String[] args) { + double a = 123.456; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["123"], + }, + { + comment: "Primitive casting: float to short", + program: ` + public class Main { + public static void main(String[] args) { + float a = 32768.0f; + short b = (short) a; + System.out.println(b); + } + } + `, + expectedLines: ["-32768"], // short wraps around + }, + { + comment: "Chained casting: double to int to byte", + program: ` + public class Main { + public static void main(String[] args) { + double a = 258.99; + int b = (int) a; + byte c = (byte) b; + System.out.println(c); + } + } + `, + expectedLines: ["2"], // 258 -> byte wraps around + }, +]; + +export const castExpressionTest = () => describe("cast expression", () => { + for (let testCase of testCases) { + const { comment: comment, program: program, expectedLines: expectedLines } = testCase; + it(comment, () => runTest(program, expectedLines)); + } +}); \ No newline at end of file diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index f574d5c9..5f76f4a8 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -24,7 +24,7 @@ import { ClassInstanceCreationExpression, ExpressionStatement, TernaryExpression, - LeftHandSide + LeftHandSide, CastExpression } from '../ast/types/blocks-and-statements' import { MethodDeclaration, UnannType } from '../ast/types/classes' import { ConstantPoolManager } from './constant-pool-manager' @@ -164,23 +164,23 @@ const normalStoreOp: { [type: string]: OPCODE } = { Z: OPCODE.ISTORE } -// const typeConversions: { [key: string]: OPCODE } = { -// 'I->F': OPCODE.I2F, -// 'I->D': OPCODE.I2D, -// 'I->J': OPCODE.I2L, -// 'I->B': OPCODE.I2B, -// 'I->C': OPCODE.I2C, -// 'I->S': OPCODE.I2S, -// 'F->D': OPCODE.F2D, -// 'F->I': OPCODE.F2I, -// 'F->J': OPCODE.F2L, -// 'D->F': OPCODE.D2F, -// 'D->I': OPCODE.D2I, -// 'D->J': OPCODE.D2L, -// 'J->I': OPCODE.L2I, -// 'J->F': OPCODE.L2F, -// 'J->D': OPCODE.L2D -// }; +const typeConversions: { [key: string]: OPCODE } = { + 'I->F': OPCODE.I2F, + 'I->D': OPCODE.I2D, + 'I->J': OPCODE.I2L, + 'I->B': OPCODE.I2B, + 'I->C': OPCODE.I2C, + 'I->S': OPCODE.I2S, + 'F->D': OPCODE.F2D, + 'F->I': OPCODE.F2I, + 'F->J': OPCODE.F2L, + 'D->F': OPCODE.D2F, + 'D->I': OPCODE.D2I, + 'D->J': OPCODE.D2L, + 'J->I': OPCODE.L2I, + 'J->F': OPCODE.L2F, + 'J->D': OPCODE.L2D +}; const typeConversionsImplicit: { [key: string]: OPCODE } = { 'I->F': OPCODE.I2F, @@ -231,17 +231,17 @@ function handleImplicitTypeConversion(fromType: string, toType: string, cg: Code } } -// function handleExplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { -// if (fromType === toType) { -// return; -// } -// const conversionKey = `${fromType}->${toType}`; -// if (conversionKey in typeConversions) { -// cg.code.push(typeConversions[conversionKey]); -// } else { -// throw new Error(`Unsupported explicit type conversion: ${conversionKey}`); -// } -// } +function handleExplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { + if (fromType === toType) { + return; + } + const conversionKey = `${fromType}->${toType}`; + if (conversionKey in typeConversions) { + cg.code.push(typeConversions[conversionKey]); + } else { + throw new Error(`Unsupported explicit type conversion: ${conversionKey}`); + } +} function generateStringConversion(valueType: string, cg: CodeGenerator): void { const stringClass = 'java/lang/String'; @@ -784,11 +784,16 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi n.argumentList.forEach((x, i) => { const argCompileResult = compile(x, cg) + let normalizedType = argCompileResult.resultType; + if (normalizedType === 'B' || normalizedType === 'S') { + normalizedType = 'I' + } + const expectedType = params?.[i] // Expected parameter type - const stackSizeChange = handleImplicitTypeConversion(argCompileResult.resultType, expectedType ?? '', cg) + const stackSizeChange = handleImplicitTypeConversion(normalizedType, expectedType ?? '', cg) maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize + stackSizeChange) - argTypes.push(argCompileResult.resultType) + argTypes.push(normalizedType) }) const argDescriptor = '(' + argTypes.join('') + ')' @@ -1241,7 +1246,41 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } return { stackSize: 1, resultType: EMPTY_TYPE } - } + }, + + CastExpression: (node: Node, cg: CodeGenerator) => { + const { expression, type } = node as CastExpression; // CastExpression node structure + const { stackSize, resultType } = compile(expression, cg); + + if ((type == 'byte' || type == 'short') && resultType != 'I') { + handleExplicitTypeConversion(resultType, 'I', cg); + handleExplicitTypeConversion('I', cg.symbolTable.generateFieldDescriptor(type), cg); + } else if (resultType == 'C') { + if (type == 'int') { + return { + stackSize, + resultType: cg.symbolTable.generateFieldDescriptor('int'), + }; + } else { + throw new Error(`Unsupported class type conversion: + ${'C'} -> ${cg.symbolTable.generateFieldDescriptor(type)}`) + } + } else if (type == 'char') { + if (resultType == 'I') { + handleExplicitTypeConversion('I', 'C', cg) + } else { + throw new Error(`Unsupported class type conversion: + ${resultType} -> ${cg.symbolTable.generateFieldDescriptor(type)}`) + } + } else { + handleExplicitTypeConversion(resultType, cg.symbolTable.generateFieldDescriptor(type), cg); + } + + return { + stackSize, + resultType: cg.symbolTable.generateFieldDescriptor(type), + } +} } class CodeGenerator { diff --git a/src/compiler/grammar.pegjs b/src/compiler/grammar.pegjs index 8f5e3ca0..8f97b417 100755 --- a/src/compiler/grammar.pegjs +++ b/src/compiler/grammar.pegjs @@ -1079,7 +1079,8 @@ MultiplicativeExpression } UnaryExpression - = PostfixExpression + = / CastExpression + / PostfixExpression / op:PrefixOp expr:UnaryExpression { return addLocInfo({ kind: "PrefixExpression", @@ -1087,7 +1088,6 @@ UnaryExpression expression: expr, }) } - / CastExpression / SwitchExpression PrefixOp @@ -1107,8 +1107,19 @@ PostfixExpression } CastExpression - = lparen PrimitiveType rparen UnaryExpression - / lparen ReferenceType rparen (LambdaExpression / !(PlusMinus) UnaryExpression) + = lparen castType:PrimitiveType rparen expr:UnaryExpression { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + })} + / lparen castType:ReferenceType rparen expr:(LambdaExpression / !(PlusMinus) UnaryExpression) { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + }) + } SwitchExpression = SwitchStatement diff --git a/src/compiler/grammar.ts b/src/compiler/grammar.ts index 4b6f1ee6..4c776ffd 100755 --- a/src/compiler/grammar.ts +++ b/src/compiler/grammar.ts @@ -1081,7 +1081,8 @@ MultiplicativeExpression } UnaryExpression - = PostfixExpression + = CastExpression + / PostfixExpression / op:PrefixOp expr:UnaryExpression { return addLocInfo({ kind: "PrefixExpression", @@ -1089,7 +1090,6 @@ UnaryExpression expression: expr, }) } - / CastExpression / SwitchExpression PrefixOp @@ -1109,8 +1109,19 @@ PostfixExpression } CastExpression - = lparen PrimitiveType rparen UnaryExpression - / lparen ReferenceType rparen (LambdaExpression / !(PlusMinus) UnaryExpression) + = lparen castType:PrimitiveType rparen expr:UnaryExpression { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + })} + / lparen castType:ReferenceType rparen expr:(LambdaExpression / !(PlusMinus) UnaryExpression) { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + }) + } SwitchExpression = SwitchStatement From fe7660387a0ba0e0cf987100002d3171e84f5479 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Sun, 26 Jan 2025 13:39:48 +0800 Subject: [PATCH 13/29] Add type cast support to the type checker --- .../__tests__/expression-extractor.test.ts | 139 ++++++++++++++++++ src/ast/astExtractor/expression-extractor.ts | 47 +++++- src/ast/types/blocks-and-statements.ts | 3 +- src/types/checker/index.ts | 78 +++++++++- 4 files changed, 256 insertions(+), 11 deletions(-) diff --git a/src/ast/__tests__/expression-extractor.test.ts b/src/ast/__tests__/expression-extractor.test.ts index 3e3238b0..9a7d03b7 100644 --- a/src/ast/__tests__/expression-extractor.test.ts +++ b/src/ast/__tests__/expression-extractor.test.ts @@ -1035,3 +1035,142 @@ describe("extract ClassInstanceCreationExpression correctly", () => { expect(ast).toEqual(expectedAst); }); }); + +describe("extract CastExpression correctly", () => { + it("extract CastExpression int to char correctly", () => { + const programStr = ` + class Test { + void test() { + char c = (char) 65; + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "LocalVariableDeclarationStatement", + localVariableType: "char", + variableDeclaratorList: [ + { + kind: "VariableDeclarator", + variableDeclaratorId: "c", + variableInitializer: { + kind: "CastExpression", + type: "char", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "65", + }, + location: expect.anything(), + }, + location: expect.anything(), + }, + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract CastExpression double to int correctly", () => { + const programStr = ` + class Test { + void test() { + int x = (int) 3.14; + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "LocalVariableDeclarationStatement", + localVariableType: "int", + variableDeclaratorList: [ + { + kind: "VariableDeclarator", + variableDeclaratorId: "x", + variableInitializer: { + kind: "CastExpression", + type: "int", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalFloatingPointLiteral", + value: "3.14", + } + }, + location: expect.anything(), + }, + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); +}); diff --git a/src/ast/astExtractor/expression-extractor.ts b/src/ast/astExtractor/expression-extractor.ts index 19782358..dcce959c 100644 --- a/src/ast/astExtractor/expression-extractor.ts +++ b/src/ast/astExtractor/expression-extractor.ts @@ -93,10 +93,9 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults { const primitiveCast = ctx.primitiveCastExpression[0]; const type = this.extractType(primitiveCast.children.primitiveType[0]); const expression = this.visit(primitiveCast.children.unaryExpression[0]); - console.debug({primitiveCast, type, expression}); return { kind: "CastExpression", - castType: type, + type: type, expression: expression, location: this.location, }; @@ -105,13 +104,41 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults { throw new Error("Invalid CastExpression format."); } - private extractType(typeCtx: any) { - if (typeCtx.Identifier) { - return typeCtx.Identifier[0].image; - } - if (typeCtx.unannPrimitiveType) { - return this.visit(typeCtx.unannPrimitiveType); + private extractType(typeCtx: any): string { + // Check for the 'primitiveType' node + if (typeCtx.name === "primitiveType" && typeCtx.children) { + const { children } = typeCtx; + + // Handle 'numericType' (e.g., int, char, float, double) + if (children.numericType) { + const numericTypeCtx = children.numericType[0]; + + if (numericTypeCtx.children.integralType) { + // Handle integral types (e.g., char, int) + const integralTypeCtx = numericTypeCtx.children.integralType[0]; + + // Extract the specific type (e.g., 'char', 'int') + for (const key in integralTypeCtx.children) { + if (integralTypeCtx.children[key][0].image) { + return integralTypeCtx.children[key][0].image; + } + } + } + + if (numericTypeCtx.children.floatingPointType) { + // Handle floating-point types (e.g., float, double) + const floatingPointTypeCtx = numericTypeCtx.children.floatingPointType[0]; + + // Extract the specific type (e.g., 'float', 'double') + for (const key in floatingPointTypeCtx.children) { + if (floatingPointTypeCtx.children[key][0].image) { + return floatingPointTypeCtx.children[key][0].image; + } + } + } + } } + throw new Error("Invalid type context in cast expression."); } @@ -204,6 +231,10 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults { } unaryExpression(ctx: UnaryExpressionCtx) { + if (ctx.primary[0].children.primaryPrefix[0].children.castExpression) { + return this.visit(ctx.primary[0].children.primaryPrefix[0].children.castExpression); + } + const node = this.visit(ctx.primary); if (ctx.UnaryPrefixOperator) { return { diff --git a/src/ast/types/blocks-and-statements.ts b/src/ast/types/blocks-and-statements.ts index fa2012c4..188a9584 100644 --- a/src/ast/types/blocks-and-statements.ts +++ b/src/ast/types/blocks-and-statements.ts @@ -119,7 +119,6 @@ export type Expression = | BinaryExpression | UnaryExpression | TernaryExpression - | CastExpression | Void; export interface Void extends BaseNode { @@ -260,7 +259,7 @@ export interface Assignment extends BaseNode { } export type LeftHandSide = ExpressionName | ArrayAccess; -export type UnaryExpression = PrefixExpression | PostfixExpression; +export type UnaryExpression = PrefixExpression | PostfixExpression | CastExpression; export interface PrefixExpression extends BaseNode { kind: "PrefixExpression"; diff --git a/src/types/checker/index.ts b/src/types/checker/index.ts index 005286e1..c9107e3e 100644 --- a/src/types/checker/index.ts +++ b/src/types/checker/index.ts @@ -1,7 +1,7 @@ import { Array as ArrayType } from '../types/arrays' import { Integer, String, Throwable, Void } from '../types/references' import { CaseConstant, Node } from '../ast/specificationTypes' -import { Type } from '../types/type' +import { PrimitiveType, Type } from '../types/type' import { ArrayRequiredError, BadOperandTypesError, @@ -63,6 +63,33 @@ export const check = (node: Node, frame: Frame = Frame.globalFrame()): Result => return typeCheckBody(node, typeCheckingFrame) } +const isCastCompatible = (fromType: Type, toType: Type): boolean => { + // Handle primitive type compatibility + if (fromType instanceof PrimitiveType && toType instanceof PrimitiveType) { + const fromName = fromType.constructor.name; + const toName = toType.constructor.name; + + console.log(fromName, toName); + + return !(fromName === 'char' && toName !== 'int'); + } + + // Handle class type compatibility + if (fromType instanceof ClassType && toType instanceof ClassType) { + // Allow upcasts (base class to derived class) or downcasts (derived class to base class) + return fromType.canBeAssigned(toType) || toType.canBeAssigned(fromType); + } + + // Handle array type compatibility + if (fromType instanceof ArrayType && toType instanceof ArrayType) { + // Ensure the content types are compatible + return isCastCompatible(fromType.getContentType(), toType.getContentType()); + } + + // Disallow other cases by default + return false; +}; + export const typeCheckBody = (node: Node, frame: Frame = Frame.globalFrame()): Result => { switch (node.kind) { case 'ArrayAccess': { @@ -192,6 +219,55 @@ export const typeCheckBody = (node: Node, frame: Frame = Frame.globalFrame()): R case 'BreakStatement': { return OK_RESULT } + + case 'CastExpression': { + let castType: Type | TypeCheckerError; + let expressionType: Type | null = null; + let expressionResult: Result; + + if ('primitiveType' in node) { + castType = frame.getType(unannTypeToString(node.primitiveType), node.primitiveType.location); + } else { + throw new Error('Invalid CastExpression: Missing type information.'); + } + + if (castType instanceof TypeCheckerError) { + return newResult(null, [castType]); + } + + if ('unaryExpression' in node) { + expressionResult = typeCheckBody(node.unaryExpression, frame); + } else { + throw new Error('Invalid CastExpression: Missing expression.'); + } + + if (expressionResult.hasErrors) { + return expressionResult; + } + + expressionType = expressionResult.currentType; + if (!expressionType) { + throw new Error('Expression in cast should have a type.'); + } + + if ( + (castType instanceof PrimitiveType && expressionType instanceof PrimitiveType) + ) { + if (!isCastCompatible(expressionType, castType)) { + return newResult(null, [ + new IncompatibleTypesError(node.location), + ]); + } + } else { + return newResult(null, [ + new IncompatibleTypesError(node.location), + ]); + } + + // If the cast is valid, return the target type + return newResult(castType); + } + case 'ClassInstanceCreationExpression': { const classIdentifier = node.unqualifiedClassInstanceCreationExpression.classOrInterfaceTypeToInstantiate From 345a8c83b59712937cef14204dd70644b467c1ed Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Mon, 17 Feb 2025 21:27:09 +0800 Subject: [PATCH 14/29] Add switch statement support to the AST extractor and the compiler --- .../switch-statement-extractor.test.ts | 435 ++++++++++++++++++ src/ast/astExtractor/statement-extractor.ts | 130 +++++- src/ast/types/blocks-and-statements.ts | 31 +- .../__tests__/__utils__/test-utils.ts | 1 - src/compiler/__tests__/index.ts | 54 +-- src/compiler/__tests__/tests/switch.test.ts | 178 +++++++ src/compiler/code-generator.ts | 248 +++++++++- src/compiler/grammar.pegjs | 37 +- src/compiler/grammar.ts | 37 +- 9 files changed, 1115 insertions(+), 36 deletions(-) create mode 100644 src/ast/__tests__/switch-statement-extractor.test.ts create mode 100644 src/compiler/__tests__/tests/switch.test.ts diff --git a/src/ast/__tests__/switch-statement-extractor.test.ts b/src/ast/__tests__/switch-statement-extractor.test.ts new file mode 100644 index 00000000..00ae21ea --- /dev/null +++ b/src/ast/__tests__/switch-statement-extractor.test.ts @@ -0,0 +1,435 @@ +import { parse } from "../parser"; +import { AST } from "../types/packages-and-modules"; + +describe("extract SwitchStatement correctly", () => { + it("extract SwitchStatement with case labels and statements correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + default: + System.out.println("Default"); + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "2", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Two"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "DefaultLabel", + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Default"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract SwitchStatement with fallthrough correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + case 2: + System.out.println("One or Two"); + break; + default: + System.out.println("Default"); + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "2", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One or Two"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "DefaultLabel", + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Default"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract SwitchStatement without default case correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + System.out.println("One"); + break; + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + console.log(JSON.stringify(ast, null, 2)); + expect(ast).toEqual(expectedAst); + }); +}); diff --git a/src/ast/astExtractor/statement-extractor.ts b/src/ast/astExtractor/statement-extractor.ts index ac2918cf..9e9f0e4c 100644 --- a/src/ast/astExtractor/statement-extractor.ts +++ b/src/ast/astExtractor/statement-extractor.ts @@ -21,6 +21,10 @@ import { PrimaryPrefixCtx, PrimarySuffixCtx, ReturnStatementCtx, + SwitchStatementCtx, + SwitchBlockCtx, + SwitchLabelCtx, + SwitchBlockStatementGroupCtx, StatementCstNode, StatementExpressionCtx, StatementWithoutTrailingSubstatementCtx, @@ -32,13 +36,17 @@ import { ExpressionStatementCtx, LocalVariableTypeCtx, VariableDeclaratorListCtx, - VariableDeclaratorCtx, -} from "java-parser"; + VariableDeclaratorCtx +} from 'java-parser' import { BasicForStatement, ExpressionStatement, IfStatement, MethodInvocation, + SwitchStatement, + SwitchCase, + CaseLabel, + DefaultLabel, Statement, StatementExpression, VariableDeclarator, @@ -80,6 +88,8 @@ export class StatementExtractor extends BaseJavaCstVisitorWithDefaults { return { kind: "BreakStatement" }; } else if (ctx.continueStatement) { return { kind: "ContinueStatement" }; + } else if (ctx.switchStatement) { + return this.visit(ctx.switchStatement); } else if (ctx.returnStatement) { const returnStatementExp = this.visit(ctx.returnStatement); return { @@ -90,6 +100,122 @@ export class StatementExtractor extends BaseJavaCstVisitorWithDefaults { } } + switchStatement(ctx: SwitchStatementCtx): SwitchStatement { + const expressionExtractor = new ExpressionExtractor(); + + return { + kind: "SwitchStatement", + expression: expressionExtractor.extract(ctx.expression[0]), + cases: ctx.switchBlock + ? this.visit(ctx.switchBlock) + : [], + location: ctx.Switch[0] + }; + } + + switchBlock(ctx: SwitchBlockCtx): Array { + const cases: Array = []; + let currentCase: SwitchCase; + + ctx.switchBlockStatementGroup?.forEach((group) => { + const extractedCase = this.visit(group); + + if (!currentCase) { + // First case in the switch block + currentCase = extractedCase; + cases.push(currentCase); + } else if (currentCase.statements && currentCase.statements.length === 0) { + // Fallthrough case, merge labels + currentCase.labels.push(...extractedCase.labels); + } else { + // New case with statements starts, push previous case and start new one + currentCase = extractedCase; + cases.push(currentCase); + } + }); + + return cases; + } + + switchBlockStatementGroup(ctx: SwitchBlockStatementGroupCtx): SwitchCase { + const blockStatementExtractor = new BlockStatementExtractor(); + + console.log(ctx.switchLabel) + + return { + kind: "SwitchCase", + labels: ctx.switchLabel.flatMap((label) => this.visit(label)), + statements: ctx.blockStatements + ? ctx.blockStatements.flatMap((blockStatements) => + blockStatements.children.blockStatement.map((stmt) => + blockStatementExtractor.extract(stmt) + ) + ) + : [], + }; + } + + // switchLabel(ctx: SwitchLabelCtx): CaseLabel | DefaultLabel { + // // Check if the context contains a "case" label + // if (ctx.caseOrDefaultLabel?.[0]?.children?.Case) { + // const expressionExtractor = new ExpressionExtractor(); + // // @ts-ignore + // const expressionCtx = ctx.caseOrDefaultLabel[0].children.caseLabelElement[0] + // .children.caseConstant[0].children.ternaryExpression[0].children; + // + // // Ensure the expression context is valid before proceeding + // if (!expressionCtx) { + // throw new Error("Invalid Case expression in switch label"); + // } + // + // const expression = expressionExtractor.ternaryExpression(expressionCtx); + // + // return { + // kind: "CaseLabel", + // expression: expression, + // }; + // } + // + // // Check if the context contains a "default" label + // if (ctx.caseOrDefaultLabel?.[0]?.children?.Default) { + // return { kind: "DefaultLabel" }; + // } + // + // // Throw an error if the context does not match expected patterns + // throw new Error("Invalid switch label: Neither 'case' nor 'default' found"); + // } + + switchLabel(ctx: SwitchLabelCtx): Array { + const expressionExtractor = new ExpressionExtractor(); + const labels: Array = []; + + // Process all case or default labels + for (const labelCtx of ctx.caseOrDefaultLabel) { + if (labelCtx.children.Case) { + // Extract the expression for the case label + const expressionCtx = labelCtx.children.caseLabelElement?.[0] + ?.children.caseConstant?.[0]?.children.ternaryExpression?.[0]?.children; + + if (!expressionCtx) { + throw new Error("Invalid Case expression in switch label"); + } + + labels.push({ + kind: "CaseLabel", + expression: expressionExtractor.ternaryExpression(expressionCtx), + }); + } else if (labelCtx.children.Default) { + labels.push({ kind: "DefaultLabel" }); + } + } + + if (labels.length === 0) { + throw new Error("Invalid switch label: Neither 'case' nor 'default' found"); + } + + return labels; + } + expressionStatement(ctx: ExpressionStatementCtx): ExpressionStatement { const stmtExp = this.visit(ctx.statementExpression); return { diff --git a/src/ast/types/blocks-and-statements.ts b/src/ast/types/blocks-and-statements.ts index 188a9584..54a1ce9d 100644 --- a/src/ast/types/blocks-and-statements.ts +++ b/src/ast/types/blocks-and-statements.ts @@ -28,7 +28,8 @@ export type Statement = | IfStatement | WhileStatement | ForStatement - | EmptyStatement; + | EmptyStatement + | SwitchStatement; export interface EmptyStatement extends BaseNode { kind: "EmptyStatement"; @@ -66,6 +67,34 @@ export interface EnhancedForStatement extends BaseNode { kind: "EnhancedForStatement"; } +export interface SwitchStatement extends BaseNode { + kind: "SwitchStatement"; + expression: Expression; // The expression to evaluate for the switch + cases: Array; +} + +export interface SwitchCase extends BaseNode { + kind: "SwitchCase"; + labels: Array; // Labels for case blocks + statements?: Array; // Statements to execute for the case +} + +export type CaseLabel = CaseLiteralLabel | CaseExpressionLabel; + +export interface CaseLiteralLabel extends BaseNode { + kind: "CaseLabel"; + expression: Literal; // Literal values: byte, short, int, char, or String +} + +export interface CaseExpressionLabel extends BaseNode { + kind: "CaseLabel"; + expression: Expression; // For future extension if needed +} + +export interface DefaultLabel extends BaseNode { + kind: "DefaultLabel"; // Represents the default case +} + export type StatementWithoutTrailingSubstatement = | Block | ExpressionStatement diff --git a/src/compiler/__tests__/__utils__/test-utils.ts b/src/compiler/__tests__/__utils__/test-utils.ts index 042ca8ef..36d11f57 100644 --- a/src/compiler/__tests__/__utils__/test-utils.ts +++ b/src/compiler/__tests__/__utils__/test-utils.ts @@ -24,7 +24,6 @@ const binaryWriter = new BinaryWriter(); export function runTest(program: string, expectedLines: string[]) { const ast = parser.parse(program); - console.log(JSON.stringify(ast, null, 2) + "\n"); expect(ast).not.toBeNull(); if (debug) { diff --git a/src/compiler/__tests__/index.ts b/src/compiler/__tests__/index.ts index e8bf633b..269d5d8f 100644 --- a/src/compiler/__tests__/index.ts +++ b/src/compiler/__tests__/index.ts @@ -1,29 +1,31 @@ -import { printlnTest } from "./tests/println.test"; -import { variableDeclarationTest } from "./tests/variableDeclaration.test"; -import { arithmeticExpressionTest } from "./tests/arithmeticExpression.test"; -import { ifElseTest } from "./tests/ifElse.test"; -import { whileTest } from "./tests/while.test"; -import { forTest } from "./tests/for.test"; -import { unaryExpressionTest } from "./tests/unaryExpression.test"; -import { methodInvocationTest } from "./tests/methodInvocation.test"; -import { importTest } from "./tests/import.test"; -import { arrayTest } from "./tests/array.test"; -import { classTest } from "./tests/class.test"; +import { printlnTest } from './tests/println.test' +import { variableDeclarationTest } from './tests/variableDeclaration.test' +import { arithmeticExpressionTest } from './tests/arithmeticExpression.test' +import { ifElseTest } from './tests/ifElse.test' +import { whileTest } from './tests/while.test' +import { forTest } from './tests/for.test' +import { unaryExpressionTest } from './tests/unaryExpression.test' +import { methodInvocationTest } from './tests/methodInvocation.test' +import { importTest } from './tests/import.test' +import { arrayTest } from './tests/array.test' +import { classTest } from './tests/class.test' import { assignmentExpressionTest } from './tests/assignmentExpression.test' import { castExpressionTest } from './tests/castExpression.test' +import { switchTest } from './tests/switch.test' -describe("compiler tests", () => { - castExpressionTest(); - printlnTest(); - variableDeclarationTest(); - arithmeticExpressionTest(); - unaryExpressionTest(); - ifElseTest(); - whileTest(); - forTest(); - methodInvocationTest(); - importTest(); - arrayTest(); - classTest(); - assignmentExpressionTest(); -}) \ No newline at end of file +describe('compiler tests', () => { + switchTest() + castExpressionTest() + printlnTest() + variableDeclarationTest() + arithmeticExpressionTest() + unaryExpressionTest() + ifElseTest() + whileTest() + forTest() + methodInvocationTest() + importTest() + arrayTest() + classTest() + assignmentExpressionTest() +}) diff --git a/src/compiler/__tests__/tests/switch.test.ts b/src/compiler/__tests__/tests/switch.test.ts new file mode 100644 index 00000000..8ba6ac82 --- /dev/null +++ b/src/compiler/__tests__/tests/switch.test.ts @@ -0,0 +1,178 @@ +import { runTest, testCase } from '../__utils__/test-utils' + +const testCases: testCase[] = [ + { + comment: 'More basic switch case', + program: ` + public class Basic { + public static void main(String[] args) { + int x = 1; + switch (x) { + case 1: + System.out.println("One"); + break; + } + } + } + `, + expectedLines: ['One'] + }, + { + comment: 'Basic switch case', + program: ` + public class Main { + public static void main(String[] args) { + int x = 2; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + case 3: + System.out.println("Three"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Two'] + }, + { + comment: 'Switch with default case', + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Default'] + }, + { + comment: 'Switch fallthrough behavior', + program: ` + public class Main { + public static void main(String[] args) { + int x = 2; + switch (x) { + case 1: + System.out.println("One"); + case 2: + System.out.println("Two"); + case 3: + System.out.println("Three"); + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Two', 'Three', 'Default'] + }, + { + comment: 'Switch with break statements', + program: ` + public class Main { + public static void main(String[] args) { + int x = 3; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + case 3: + System.out.println("Three"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Three'] + }, + { + comment: 'Switch with strings', + program: ` + public class Main { + public static void main(String[] args) { + String day = "Tuesday"; + switch (day) { + case "Monday": + System.out.println("Start of the week"); + break; + case "Tuesday": + System.out.println("Second day"); + break; + case "Friday": + System.out.println("Almost weekend"); + break; + default: + System.out.println("Midweek or weekend"); + } + } + } + `, + expectedLines: ['Second day'] + }, + { + comment: 'Nested switch statements', + program: ` + public class Main { + public static void main(String[] args) { + int outer = 2; + int inner = 1; + switch (outer) { + case 1: + switch (inner) { + case 1: + System.out.println("Inner One"); + break; + case 2: + System.out.println("Inner Two"); + break; + } + break; + case 2: + switch (inner) { + case 1: + System.out.println("Outer Two, Inner One"); + break; + case 2: + System.out.println("Outer Two, Inner Two"); + break; + } + break; + default: + System.out.println("Default case"); + } + } + } + `, + expectedLines: ['Outer Two, Inner One'] + } +] + +export const switchTest = () => + describe('Switch statements', () => { + for (let testCase of testCases) { + const { comment, program, expectedLines } = testCase + it(comment, () => runTest(program, expectedLines)) + } + }) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 5f76f4a8..73365638 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -24,7 +24,7 @@ import { ClassInstanceCreationExpression, ExpressionStatement, TernaryExpression, - LeftHandSide, CastExpression + LeftHandSide, CastExpression, SwitchStatement, SwitchCase, CaseLabel } from '../ast/types/blocks-and-statements' import { MethodDeclaration, UnannType } from '../ast/types/classes' import { ConstantPoolManager } from './constant-pool-manager' @@ -272,6 +272,14 @@ function generateStringConversion(valueType: string, cg: CodeGenerator): void { cg.code.push(OPCODE.INVOKESTATIC, 0, methodIndex); } +function hashCode(str: string): number { + let hash = 0; + for (let i = 0; i < str.length; i++) { + hash = (hash * 31 + str.charCodeAt(i)) | 0; // Simulate Java's overflow behavior + } + return hash; +} + // function generateBooleanConversion(type: string, cg: CodeGenerator): number { // let stackChange = 0; // Tracks changes to the stack size // @@ -437,8 +445,16 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }, BreakStatement: (node: Node, cg: CodeGenerator) => { - cg.addBranchInstr(OPCODE.GOTO, cg.loopLabels[cg.loopLabels.length - 1][1]) - return { stackSize: 0, resultType: EMPTY_TYPE } + if (cg.loopLabels.length > 0) { + // If inside a loop, break jumps to the end of the loop + cg.addBranchInstr(OPCODE.GOTO, cg.loopLabels[cg.loopLabels.length - 1][1]); + } else if (cg.switchLabels.length > 0) { + // If inside a switch, break jumps to the switch's end label + cg.addBranchInstr(OPCODE.GOTO, cg.switchLabels[cg.switchLabels.length - 1]); + } else { + throw new Error("Break statement not inside a loop or switch statement"); + } + return { stackSize: 0, resultType: EMPTY_TYPE }; }, ContinueStatement: (node: Node, cg: CodeGenerator) => { @@ -1280,7 +1296,230 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi stackSize, resultType: cg.symbolTable.generateFieldDescriptor(type), } -} + }, + + SwitchStatement: (node: Node, cg: CodeGenerator) => { + const { expression, cases } = node as SwitchStatement; + + // Compile the switch expression + const { stackSize: exprStackSize, resultType } = compile(expression, cg); + let maxStack = exprStackSize; + + const caseLabels: Label[] = cases.map(() => cg.generateNewLabel()); + const defaultLabel = cg.generateNewLabel(); + const endLabel = cg.generateNewLabel(); + + // Track the switch statement's end label + cg.switchLabels.push(endLabel); + + if (["I", "B", "S", "C"].includes(resultType)) { + const caseValues: number[] = []; + const caseLabelMap: Map = new Map(); + let hasDefault = false; + + cases.forEach((caseGroup, index) => { + caseGroup.labels.forEach((label) => { + if (label.kind === "CaseLabel") { + const value = parseInt((label.expression as Literal).literalType.value); + caseValues.push(value); + caseLabelMap.set(value, caseLabels[index]); + } else if (label.kind === "DefaultLabel") { + caseLabels[index] = defaultLabel; + hasDefault = true; + } + }); + }); + + const [minValue, maxValue] = [Math.min(...caseValues), Math.max(...caseValues)]; + const useTableSwitch = maxValue - minValue < caseValues.length * 2; + + if (useTableSwitch) { + cg.code.push(OPCODE.TABLESWITCH); + + // Ensure 4-byte alignment for TABLESWITCH + while (cg.code.length % 4 !== 0) { + cg.code.push(0); // Padding bytes (JVM requires alignment) + } + + // Add default branch (jump to default label) + cg.code.push(0, 0, 0, defaultLabel.offset); + + // Push low and high values (min and max case values) + cg.code.push(minValue >> 24, (minValue >> 16) & 0xff, (minValue >> 8) & 0xff, minValue & 0xff); + cg.code.push(maxValue >> 24, (maxValue >> 16) & 0xff, (maxValue >> 8) & 0xff, maxValue & 0xff); + + // Generate branch table (map each value to a case label) + for (let i = minValue; i <= maxValue; i++) { + const caseIndex = caseValues.indexOf(i); + cg.code.push(0, 0, 0, caseIndex !== -1 ? caseLabels[caseIndex].offset + : defaultLabel.offset); + } + } else { + cg.code.push(OPCODE.LOOKUPSWITCH); + + // Ensure 4-byte alignment for LOOKUPSWITCH + while (cg.code.length % 4 !== 0) { + cg.code.push(0); + } + + // Add default branch (jump to default label) + cg.code.push(0, 0, 0, defaultLabel.offset); + + // Push the number of case-value pairs + cg.code.push((caseValues.length >> 24) & 0xff, (caseValues.length >> 16) & 0xff, + (caseValues.length >> 8) & 0xff, caseValues.length & 0xff); + + // Generate lookup table (pairs of case values and corresponding labels) + caseValues.forEach((value, index) => { + cg.code.push(value >> 24, (value >> 16) & 0xff, (value >> 8) & 0xff, value & 0xff); + cg.code.push(0, 0, 0, caseLabels[index].offset); + }); + } + + // **Process case bodies with proper fallthrough handling** + let previousCase: SwitchCase | null = null; + + cases.forEach((caseGroup, index) => { + caseLabels[index].offset = cg.code.length; + + // Ensure statements array is always defined + caseGroup.statements = caseGroup.statements || []; + + // If previous case had no statements, merge labels (fallthrough) + if (previousCase && (previousCase.statements?.length ?? 0) === 0) { + previousCase.labels.push(...caseGroup.labels); + } + + // Compile case statements + caseGroup.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + + // Add jump to the end label if the case has statements + if (caseGroup.statements.length > 0) { + cg.addBranchInstr(OPCODE.GOTO, endLabel); + } + + previousCase = caseGroup; + }); + + // **Process default case** + defaultLabel.offset = cg.code.length; + if (hasDefault) { + const defaultCase = cases.find((caseGroup) => + caseGroup.labels.some((label) => label.kind === "DefaultLabel") + ); + if (defaultCase) { + defaultCase.statements = defaultCase.statements || []; + defaultCase.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + } + } + + endLabel.offset = cg.code.length; + + } else if (resultType === "Ljava/lang/String;") { + // **String cases** + const hashCaseMap: Map = new Map(); + const hashCodeVarIndex = cg.maxLocals++; + + // Generate `hashCode()` call + cg.code.push( + OPCODE.INVOKEVIRTUAL, + 0, + cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "hashCode", "()I") + ); + cg.code.push(OPCODE.ISTORE, hashCodeVarIndex); + + cases.forEach((caseGroup, index) => { + caseGroup.labels.forEach((label) => { + if (label.kind === "CaseLabel") { + const caseValue = (label.expression as Literal).literalType.value; + const hashCodeValue = hashCode(caseValue); + if (!hashCaseMap.has(hashCodeValue)) { + hashCaseMap.set(hashCodeValue, caseLabels[index]); + } + } else if (label.kind === "DefaultLabel") { + caseLabels[index] = defaultLabel; + } + }); + }); + + // **Compare hashCodes** + hashCaseMap.forEach((label, hashCode) => { + cg.code.push(OPCODE.ILOAD, hashCodeVarIndex); + cg.code.push(OPCODE.BIPUSH, hashCode); + cg.addBranchInstr(OPCODE.IF_ICMPEQ, label); + }); + + cg.addBranchInstr(OPCODE.GOTO, defaultLabel); + + // **Process case bodies** + let previousCase: SwitchCase | null = null; + + cases.forEach((caseGroup, index) => { + caseLabels[index].offset = cg.code.length; + + // Ensure statements array is always defined + caseGroup.statements = caseGroup.statements || []; + + // Handle fallthrough + if (previousCase && (previousCase.statements?.length ?? 0) === 0) { + previousCase.labels.push(...caseGroup.labels); + } + + // Generate string comparison + const caseValue = caseGroup.labels.find((label): label is CaseLabel => label.kind === "CaseLabel"); + if (caseValue) { + const caseStr = (caseValue.expression as Literal).literalType.value; + const caseStrIndex = cg.constantPoolManager.indexStringInfo(caseStr); + cg.code.push(OPCODE.LDC, caseStrIndex); + cg.code.push( + OPCODE.INVOKEVIRTUAL, + 0, + cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "equals", "(Ljava/lang/Object;)Z") + ); + const caseEndLabel = cg.generateNewLabel(); + cg.addBranchInstr(OPCODE.IFEQ, caseEndLabel); + + // Compile case statements + caseGroup.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + + cg.addBranchInstr(OPCODE.GOTO, endLabel); + caseEndLabel.offset = cg.code.length; + } + + previousCase = caseGroup; + }); + + // **Process default case** + defaultLabel.offset = cg.code.length; + const defaultCase = cases.find((caseGroup) => + caseGroup.labels.some((label) => label.kind === "DefaultLabel") + ); + if (defaultCase) { + defaultCase.statements = defaultCase.statements || []; + defaultCase.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + } + + endLabel.offset = cg.code.length; + } else { + throw new Error(`Switch statements only support byte, short, int, char, or String types. Found: ${resultType}`); + } + + cg.switchLabels.pop(); + + return { stackSize: maxStack, resultType: EMPTY_TYPE }; + } } class CodeGenerator { @@ -1290,6 +1529,7 @@ class CodeGenerator { stackSize: number = 0 labels: Label[] = [] loopLabels: Label[][] = [] + switchLabels: Label[] = [] code: number[] = [] constructor(symbolTable: SymbolTable, constantPoolManager: ConstantPoolManager) { diff --git a/src/compiler/grammar.pegjs b/src/compiler/grammar.pegjs index 8f97b417..505f648e 100755 --- a/src/compiler/grammar.pegjs +++ b/src/compiler/grammar.pegjs @@ -774,7 +774,42 @@ AssertStatement = assert Expression (colon Expression) semicolon SwitchStatement - = TO_BE_ADDED + = switch lparen expr:Expression rparen lcurly + cases:SwitchBlock? + rcurly { + return addLocInfo({ + kind: "SwitchStatement", + expression: expr, + cases: cases ?? [], + }); + } + +SwitchBlock + = cases:SwitchBlockStatementGroup* { + return cases; + } + +SwitchBlockStatementGroup + = labels:SwitchLabel+ stmts:BlockStatement* { + return { + kind: "SwitchBlockStatementGroup", + labels: labels, + statements: stmts, + }; + } + +SwitchLabel + = case expr:Expression colon { + return { + kind: "CaseLabel", + expression: expr, + }; + } + / default colon { + return { + kind: "DefaultLabel", + }; + } DoStatement = do body:Statement while lparen expr:Expression rparen semicolon { diff --git a/src/compiler/grammar.ts b/src/compiler/grammar.ts index 4c776ffd..57470926 100755 --- a/src/compiler/grammar.ts +++ b/src/compiler/grammar.ts @@ -776,7 +776,42 @@ AssertStatement = assert Expression (colon Expression) semicolon SwitchStatement - = TO_BE_ADDED + = switch lparen expr:Expression rparen lcurly + cases:SwitchBlock? + rcurly { + return addLocInfo({ + kind: "SwitchStatement", + expression: expr, + cases: cases ?? [], + }); + } + +SwitchBlock + = cases:SwitchBlockStatementGroup* { + return cases; + } + +SwitchBlockStatementGroup + = labels:SwitchLabel+ stmts:BlockStatement* { + return { + kind: "SwitchBlockStatementGroup", + labels: labels, + statements: stmts, + }; + } + +SwitchLabel + = case expr:Expression colon { + return { + kind: "CaseLabel", + expression: expr, + }; + } + / default colon { + return { + kind: "DefaultLabel", + }; + } DoStatement = do body:Statement while lparen expr:Expression rparen semicolon { From 2038d8b0a1c12beb42f92f99dd10d39c0d474581 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 18 Feb 2025 10:20:55 +0800 Subject: [PATCH 15/29] Add switch statement support to the compiler for integer like types (int, byte, short, char) --- src/compiler/__tests__/tests/switch.test.ts | 2 +- src/compiler/code-generator.ts | 53 +++++++++++++++------ 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/src/compiler/__tests__/tests/switch.test.ts b/src/compiler/__tests__/tests/switch.test.ts index 8ba6ac82..919002ae 100644 --- a/src/compiler/__tests__/tests/switch.test.ts +++ b/src/compiler/__tests__/tests/switch.test.ts @@ -4,7 +4,7 @@ const testCases: testCase[] = [ { comment: 'More basic switch case', program: ` - public class Basic { + public class Main { public static void main(String[] args) { int x = 1; switch (x) { diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 73365638..bfe690a4 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -204,7 +204,6 @@ function areClassTypesCompatible(fromType: string, toType: string): boolean { } function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator): number { - console.debug(`Converting from: ${fromType}, to: ${toType}`) if (fromType === toType || toType.replace(/^L|;$/g, '') === 'java/lang/String') { return 0; } @@ -968,8 +967,6 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi if (op === '+' && (leftType === 'Ljava/lang/String;' || rightType === 'Ljava/lang/String;')) { - console.debug(`String concatenation detected: ${leftType} ${op} ${rightType}`) - if (leftType !== 'Ljava/lang/String;') { generateStringConversion(leftType, cg); } @@ -995,10 +992,6 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi let finalType = leftType; if (leftType !== rightType) { - console.debug( - `Type mismatch detected: leftType=${leftType}, rightType=${rightType}. Applying implicit conversions.` - ); - const conversionKeyLeft = `${leftType}->${rightType}` const conversionKeyRight = `${rightType}->${leftType}` @@ -1308,6 +1301,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const caseLabels: Label[] = cases.map(() => cg.generateNewLabel()); const defaultLabel = cg.generateNewLabel(); const endLabel = cg.generateNewLabel(); + const positionOffset = cg.code.length; // Track the switch statement's end label cg.switchLabels.push(endLabel); @@ -1332,54 +1326,73 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const [minValue, maxValue] = [Math.min(...caseValues), Math.max(...caseValues)]; const useTableSwitch = maxValue - minValue < caseValues.length * 2; + const caseLabelIndex: number[] = [] + let indexTracker = cg.code.length; if (useTableSwitch) { cg.code.push(OPCODE.TABLESWITCH); + indexTracker++ // Ensure 4-byte alignment for TABLESWITCH while (cg.code.length % 4 !== 0) { cg.code.push(0); // Padding bytes (JVM requires alignment) + indexTracker++ } // Add default branch (jump to default label) cg.code.push(0, 0, 0, defaultLabel.offset); + caseLabelIndex.push(indexTracker + 3); + indexTracker += 4 // Push low and high values (min and max case values) cg.code.push(minValue >> 24, (minValue >> 16) & 0xff, (minValue >> 8) & 0xff, minValue & 0xff); cg.code.push(maxValue >> 24, (maxValue >> 16) & 0xff, (maxValue >> 8) & 0xff, maxValue & 0xff); + indexTracker += 8 // Generate branch table (map each value to a case label) for (let i = minValue; i <= maxValue; i++) { const caseIndex = caseValues.indexOf(i); cg.code.push(0, 0, 0, caseIndex !== -1 ? caseLabels[caseIndex].offset : defaultLabel.offset); + caseLabelIndex.push(indexTracker + 3); + indexTracker += 4; } } else { cg.code.push(OPCODE.LOOKUPSWITCH); + indexTracker++; // Ensure 4-byte alignment for LOOKUPSWITCH while (cg.code.length % 4 !== 0) { cg.code.push(0); + indexTracker++ } // Add default branch (jump to default label) cg.code.push(0, 0, 0, defaultLabel.offset); + caseLabelIndex.push(indexTracker + 3); + indexTracker += 4 // Push the number of case-value pairs cg.code.push((caseValues.length >> 24) & 0xff, (caseValues.length >> 16) & 0xff, (caseValues.length >> 8) & 0xff, caseValues.length & 0xff); + indexTracker += 4 // Generate lookup table (pairs of case values and corresponding labels) caseValues.forEach((value, index) => { cg.code.push(value >> 24, (value >> 16) & 0xff, (value >> 8) & 0xff, value & 0xff); cg.code.push(0, 0, 0, caseLabels[index].offset); + caseLabelIndex.push(indexTracker + 7); + indexTracker += 8; }); } // **Process case bodies with proper fallthrough handling** let previousCase: SwitchCase | null = null; - cases.forEach((caseGroup, index) => { + const nonDefaultCases = cases.filter((caseGroup) => + caseGroup.labels.some((label) => label.kind === "CaseLabel")) + + nonDefaultCases.forEach((caseGroup, index) => { caseLabels[index].offset = cg.code.length; // Ensure statements array is always defined @@ -1396,11 +1409,6 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi maxStack = Math.max(maxStack, stackSize); }); - // Add jump to the end label if the case has statements - if (caseGroup.statements.length > 0) { - cg.addBranchInstr(OPCODE.GOTO, endLabel); - } - previousCase = caseGroup; }); @@ -1419,6 +1427,17 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } } + cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset; + + for (let i = 1; i < caseLabelIndex.length; i++) { + cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset + } + + console.debug(cg.code) + + console.debug(caseLabels[caseLabels.length - 1]) + console.debug(caseLabels.splice(0, caseLabels.length - 1)) + endLabel.offset = cg.code.length; } else if (resultType === "Ljava/lang/String;") { @@ -1432,6 +1451,8 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi 0, cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "hashCode", "()I") ); + + // const switchCaseExpressionHash = hashCode(); cg.code.push(OPCODE.ISTORE, hashCodeVarIndex); cases.forEach((caseGroup, index) => { @@ -1460,7 +1481,10 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi // **Process case bodies** let previousCase: SwitchCase | null = null; - cases.forEach((caseGroup, index) => { + const nonDefaultCases = cases.filter((caseGroup) => + caseGroup.labels.some((label) => label.kind === "CaseLabel")) + + nonDefaultCases.forEach((caseGroup, index) => { caseLabels[index].offset = cg.code.length; // Ensure statements array is always defined @@ -1491,7 +1515,6 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi maxStack = Math.max(maxStack, stackSize); }); - cg.addBranchInstr(OPCODE.GOTO, endLabel); caseEndLabel.offset = cg.code.length; } From 91691ad82c13490855db6b0aac37e787b6441df1 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 18 Feb 2025 16:36:13 +0800 Subject: [PATCH 16/29] Add switch statement support to the compiler for String type --- src/compiler/__tests__/tests/switch.test.ts | 25 ++++ src/compiler/code-generator.ts | 143 ++++++++++++-------- 2 files changed, 110 insertions(+), 58 deletions(-) diff --git a/src/compiler/__tests__/tests/switch.test.ts b/src/compiler/__tests__/tests/switch.test.ts index 919002ae..003e9ff4 100644 --- a/src/compiler/__tests__/tests/switch.test.ts +++ b/src/compiler/__tests__/tests/switch.test.ts @@ -166,6 +166,31 @@ const testCases: testCase[] = [ } `, expectedLines: ['Outer Two, Inner One'] + }, + + { + comment: 'Switch with far apart cases', + program: ` + public class Main { + public static void main(String[] args) { + int x = 1331; + switch (x) { + case 1: + System.out.println("No"); + break; + case 1331: + System.out.println("Yes"); + break; + case 999999999: + System.out.println("No"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Yes'] } ] diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index bfe690a4..a59ba776 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -274,7 +274,7 @@ function generateStringConversion(valueType: string, cg: CodeGenerator): void { function hashCode(str: string): number { let hash = 0; for (let i = 0; i < str.length; i++) { - hash = (hash * 31 + str.charCodeAt(i)) | 0; // Simulate Java's overflow behavior + hash = ((hash * 31) + str.charCodeAt(i)); // Simulate Java's overflow behavior } return hash; } @@ -1301,7 +1301,6 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const caseLabels: Label[] = cases.map(() => cg.generateNewLabel()); const defaultLabel = cg.generateNewLabel(); const endLabel = cg.generateNewLabel(); - const positionOffset = cg.code.length; // Track the switch statement's end label cg.switchLabels.push(endLabel); @@ -1310,6 +1309,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const caseValues: number[] = []; const caseLabelMap: Map = new Map(); let hasDefault = false; + const positionOffset = cg.code.length; cases.forEach((caseGroup, index) => { caseGroup.labels.forEach((label) => { @@ -1433,33 +1433,25 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset } - console.debug(cg.code) - - console.debug(caseLabels[caseLabels.length - 1]) - console.debug(caseLabels.splice(0, caseLabels.length - 1)) - endLabel.offset = cg.code.length; } else if (resultType === "Ljava/lang/String;") { - // **String cases** + // **String Switch Handling** const hashCaseMap: Map = new Map(); - const hashCodeVarIndex = cg.maxLocals++; - // Generate `hashCode()` call + // Compute and store hashCode() cg.code.push( OPCODE.INVOKEVIRTUAL, 0, cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "hashCode", "()I") ); - // const switchCaseExpressionHash = hashCode(); - cg.code.push(OPCODE.ISTORE, hashCodeVarIndex); - + // Create lookup table for hashCodes cases.forEach((caseGroup, index) => { caseGroup.labels.forEach((label) => { if (label.kind === "CaseLabel") { const caseValue = (label.expression as Literal).literalType.value; - const hashCodeValue = hashCode(caseValue); + const hashCodeValue = hashCode(caseValue.slice(1, caseValue.length - 1)); if (!hashCaseMap.has(hashCodeValue)) { hashCaseMap.set(hashCodeValue, caseLabels[index]); } @@ -1469,63 +1461,91 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }); }); - // **Compare hashCodes** - hashCaseMap.forEach((label, hashCode) => { - cg.code.push(OPCODE.ILOAD, hashCodeVarIndex); - cg.code.push(OPCODE.BIPUSH, hashCode); - cg.addBranchInstr(OPCODE.IF_ICMPEQ, label); - }); + const caseLabelIndex: number[] = [] + let indexTracker = cg.code.length; + const positionOffset = cg.code.length; - cg.addBranchInstr(OPCODE.GOTO, defaultLabel); + // **LOOKUPSWITCH Implementation** + cg.code.push(OPCODE.LOOKUPSWITCH); + indexTracker++ - // **Process case bodies** - let previousCase: SwitchCase | null = null; + // Ensure 4-byte alignment + while (cg.code.length % 4 !== 0) { + cg.code.push(0); + indexTracker++ + } - const nonDefaultCases = cases.filter((caseGroup) => - caseGroup.labels.some((label) => label.kind === "CaseLabel")) + // Default jump target + cg.code.push(0, 0, 0, defaultLabel.offset); + caseLabelIndex.push(indexTracker + 3); + indexTracker += 4; - nonDefaultCases.forEach((caseGroup, index) => { - caseLabels[index].offset = cg.code.length; - // Ensure statements array is always defined - caseGroup.statements = caseGroup.statements || []; + // Number of case-value pairs + cg.code.push((hashCaseMap.size >> 24) & 0xff, (hashCaseMap.size >> 16) & 0xff, + (hashCaseMap.size >> 8) & 0xff, hashCaseMap.size & 0xff); + indexTracker += 4; - // Handle fallthrough - if (previousCase && (previousCase.statements?.length ?? 0) === 0) { - previousCase.labels.push(...caseGroup.labels); - } + // Populate LOOKUPSWITCH + hashCaseMap.forEach((label, hashCode) => { + cg.code.push(hashCode >> 24, (hashCode >> 16) & 0xff, (hashCode >> 8) & 0xff, hashCode & 0xff); + cg.code.push(0, 0, 0, label.offset); + caseLabelIndex.push(indexTracker + 7); + indexTracker += 8; + }); - // Generate string comparison - const caseValue = caseGroup.labels.find((label): label is CaseLabel => label.kind === "CaseLabel"); - if (caseValue) { - const caseStr = (caseValue.expression as Literal).literalType.value; - const caseStrIndex = cg.constantPoolManager.indexStringInfo(caseStr); - cg.code.push(OPCODE.LDC, caseStrIndex); - cg.code.push( - OPCODE.INVOKEVIRTUAL, - 0, - cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "equals", "(Ljava/lang/Object;)Z") - ); - const caseEndLabel = cg.generateNewLabel(); - cg.addBranchInstr(OPCODE.IFEQ, caseEndLabel); + // **Case Handling** + let previousCase: SwitchCase | null = null; - // Compile case statements - caseGroup.statements.forEach((statement) => { - const { stackSize } = compile(statement, cg); - maxStack = Math.max(maxStack, stackSize); - }); + cases.filter((caseGroup) => + caseGroup.labels.some((label) => label.kind === "CaseLabel")) + .forEach((caseGroup, index) => { + caseLabels[index].offset = cg.code.length; - caseEndLabel.offset = cg.code.length; - } + // Ensure statements exist + caseGroup.statements = caseGroup.statements || []; - previousCase = caseGroup; - }); + // Handle fallthrough + if (previousCase && (previousCase.statements?.length ?? 0) === 0) { + previousCase.labels.push(...caseGroup.labels); + } - // **Process default case** + // **String Comparison for Collisions** + const caseValue = caseGroup.labels.find((label): label is CaseLabel => label.kind === "CaseLabel"); + if (caseValue) { + // TODO: check for actual String equality instead of just rely on hashCode equality + // (see the commented out code below) + + // const caseStr = (caseValue.expression as Literal).literalType.value; + // const caseStrIndex = cg.constantPoolManager.indexStringInfo(caseStr); + + // cg.code.push(OPCODE.LDC, caseStrIndex); + // cg.code.push( + // OPCODE.INVOKEVIRTUAL, + // 0, + // cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "equals", "(Ljava/lang/Object;)Z") + // ); + // + const caseEndLabel = cg.generateNewLabel(); + // cg.addBranchInstr(OPCODE.IFEQ, caseEndLabel); + + // Compile case statements + caseGroup.statements.forEach((statement) => { + const { stackSize } = compile(statement, cg); + maxStack = Math.max(maxStack, stackSize); + }); + + caseEndLabel.offset = cg.code.length; + } + + previousCase = caseGroup; + }); + + // **Default Case Handling** defaultLabel.offset = cg.code.length; const defaultCase = cases.find((caseGroup) => - caseGroup.labels.some((label) => label.kind === "DefaultLabel") - ); + caseGroup.labels.some((label) => label.kind === "DefaultLabel")); + if (defaultCase) { defaultCase.statements = defaultCase.statements || []; defaultCase.statements.forEach((statement) => { @@ -1534,7 +1554,14 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }); } + cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset; + + for (let i = 1; i < caseLabelIndex.length; i++) { + cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset + } + endLabel.offset = cg.code.length; + } else { throw new Error(`Switch statements only support byte, short, int, char, or String types. Found: ${resultType}`); } From 54c0e059c33744e8f8d75ca7d821b34b634e159c Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 4 Mar 2025 09:26:05 +0800 Subject: [PATCH 17/29] Add overloading support to the compiler --- src/compiler/__tests__/index.ts | 2 + .../__tests__/tests/methodOverloading.test.ts | 192 ++++++++++++++++++ src/compiler/code-generator.ts | 119 ++++++----- 3 files changed, 263 insertions(+), 50 deletions(-) create mode 100644 src/compiler/__tests__/tests/methodOverloading.test.ts diff --git a/src/compiler/__tests__/index.ts b/src/compiler/__tests__/index.ts index 269d5d8f..ba44159d 100644 --- a/src/compiler/__tests__/index.ts +++ b/src/compiler/__tests__/index.ts @@ -12,8 +12,10 @@ import { classTest } from './tests/class.test' import { assignmentExpressionTest } from './tests/assignmentExpression.test' import { castExpressionTest } from './tests/castExpression.test' import { switchTest } from './tests/switch.test' +import { methodOverloadingTest } from './tests/methodOverloading.test' describe('compiler tests', () => { + methodOverloadingTest() switchTest() castExpressionTest() printlnTest() diff --git a/src/compiler/__tests__/tests/methodOverloading.test.ts b/src/compiler/__tests__/tests/methodOverloading.test.ts new file mode 100644 index 00000000..5d7322be --- /dev/null +++ b/src/compiler/__tests__/tests/methodOverloading.test.ts @@ -0,0 +1,192 @@ +import { + runTest, + testCase, +} from "../__utils__/test-utils"; + +const testCases: testCase[] = [ + { + comment: "Basic method overloading", + program: ` + public class Main { + public static void f(int x) { + System.out.println("int: " + x); + } + public static void f(double x) { + System.out.println("double: " + x); + } + public static void main(String[] args) { + f(5); + f(5.5); + } + } + `, + expectedLines: ["int: 5", "double: 5.5"], + }, + { + comment: "Overloaded methods with different parameter counts", + program: ` + public class Main { + public static void f(int x) { + System.out.println("single param: " + x); + } + public static void f(int x, int y) { + System.out.println("two params: " + (x + y)); + } + public static void main(String[] args) { + f(3); + f(3, 4); + } + } + `, + expectedLines: ["single param: 3", "two params: 7"], + }, + { + comment: "Method overloading with different return types", + program: ` + public class Main { + public static int f(int x) { + return x * 2; + } + public static String f(String s) { + return s + "!"; + } + public static void main(String[] args) { + System.out.println(f(4)); + System.out.println(f("Hello")); + } + } + `, + expectedLines: ["8", "Hello!"], + }, + { + comment: "Overloading with implicit type conversion", + program: ` + public class Main { + public static void f(int x) { + System.out.println("int version: " + x); + } + public static void f(long x) { + System.out.println("long version: " + x); + } + public static void main(String[] args) { + f(10); // should call int version + f(10L); // should call long version + } + } + `, + expectedLines: ["int version: 10", "long version: 10"], + }, + { + comment: "Ambiguous method overloading", + program: ` + public class Main { + public static void f(int x, double y) { + System.out.println("int, double"); + } + public static void f(double x, int y) { + System.out.println("double, int"); + } + public static void main(String[] args) { + f(5, 5.0); + f(5.0, 5); + } + } + `, + expectedLines: ["int, double", "double, int"], + }, + { + comment: "Overloading with reference types", + program: ` + public class Main { + public static void f(String s) { + System.out.println("String"); + } + public static void f(Main m) { + System.out.println("Main"); + } + public static void main(String[] args) { + f("Hello"); // should call String version + f(new Main()); // should call Main version + } + } + `, + expectedLines: ["String", "Main"], + }, + { + comment: "Overloaded instance and static methods", + program: ` + public class Main { + public void f() { + System.out.println("Instance method"); + } + public static void f(int x) { + System.out.println("Static method with int: " + x); + } + public static void main(String[] args) { + Main obj = new Main(); + obj.f(); + f(5); + } + } + `, + expectedLines: ["Instance method", "Static method with int: 5"], + }, + { + comment: "Overloaded instance methods", + program: ` + public class Main { + public void f(int x) { + System.out.println("Instance int: " + x); + } + public void f(double x) { + System.out.println("Instance double: " + x); + } + public static void main(String[] args) { + Main obj = new Main(); + obj.f(5); + obj.f(5.5); + } + } + `, + expectedLines: ["Instance int: 5", "Instance double: 5.5"], + }, + { + comment: "Implicit conversion during method invocation", + program: ` + public class Main { + public static void f(double x) { + System.out.println("Converted double: " + x); + } + public static void main(String[] args) { + f(10); // Implicitly converts int to double + } + } + `, + expectedLines: ["Converted double: 10.0"], + }, + { + comment: "Overloading with widening conversion", + program: ` + public class Main { + public static void f(long x) { + System.out.println("long version: " + x); + } + public static void f(double x) { + System.out.println("double version: " + x); + } + public static void main(String[] args) { + f(5); // Should call long version + f(5.0f); // Should call double version + } + } + `, + expectedLines: ["long version: 5", "double version: 5.0"], + } +]; + +export const methodOverloadingTest = () => describe("method overloading", () => { + for (let testCase of testCases) { + const { comment, program, expectedLines } = testCase; + it(comment, () => runTest(program, expectedLines)); + } +}); \ No newline at end of file diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index a59ba776..e290532d 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -335,6 +335,16 @@ function hashCode(str: string): number { // return stackChange; // Return the net change in stack size // } +function getExpressionType(node: Node, cg: CodeGenerator): string { + if (!(node.kind in codeGenerators)) { + throw new ConstructNotSupportedError(node.kind); + } + const originalCode = [...cg.code]; // Preserve the original code state + const resultType = codeGenerators[node.kind](node, cg).resultType; + cg.code = originalCode; // Restore the original code state + return resultType; +} + const isNullLiteral = (node: Node) => { return node.kind === 'Literal' && node.literalType.kind === 'NullLiteral' } @@ -778,24 +788,51 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi ) } - const argTypes: Array = [] - - const methodInfo = symbolInfos[symbolInfos.length - 1] as MethodInfos - if (!methodInfo || methodInfo.length === 0) { + const methodInfos = symbolInfos[symbolInfos.length - 1] as MethodInfos + if (!methodInfos || methodInfos.length === 0) { throw new Error(`No method information found for ${n.identifier}`) } - const fullDescriptor = methodInfo[0].typeDescriptor // Full descriptor, e.g., "(Ljava/lang/String;C)V" - const paramDescriptor = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) // Extract "Ljava/lang/String;C" - const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) + const methodMatches: MethodInfos = []; + const argumentDescriptors = n.argumentList.map(arg => getExpressionType(arg, cg)); - // Parse individual parameter types - if (params && params.length !== n.argumentList.length) { - throw new Error( - `Parameter mismatch: expected ${params?.length || 0}, got ${n.argumentList.length}` - ) + for (const methodInfo of methodInfos) { + const paramDescriptor = methodInfo.typeDescriptor.slice(1, methodInfo.typeDescriptor.indexOf(')')); + const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || []; + + if (params && params.length === argumentDescriptors.length) { + let match = true; + + for (let i = 0; i < params.length; i++) { + if ((argumentDescriptors[i] == 'B' || argumentDescriptors[i] == 'S') + && paramDescriptor[i] == 'I') { + continue + } + + if ((params[i] !== argumentDescriptors[i]) && + (typeConversionsImplicit[`${argumentDescriptors[i]}->${params[i]}`] === undefined) && + (!areClassTypesCompatible(argumentDescriptors[i], params[i]))) { + match = false; + break; + } + } + if (match) { + methodMatches.push(methodInfo); + } + } } + if (methodMatches.length === 0) { + throw new InvalidMethodCallError( + `No method matching signature ${n.identifier}(${argumentDescriptors.join(',')}) found.` + ); + } + + const selectedMethod = methodMatches[0]; + const fullDescriptor = selectedMethod.typeDescriptor // Full descriptor, e.g., "(Ljava/lang/String;C)V" + const paramDescriptor = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) // Extract "Ljava/lang/String;C" + const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] + n.argumentList.forEach((x, i) => { const argCompileResult = compile(x, cg) @@ -807,47 +844,29 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const expectedType = params?.[i] // Expected parameter type const stackSizeChange = handleImplicitTypeConversion(normalizedType, expectedType ?? '', cg) maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize + stackSizeChange) - - argTypes.push(normalizedType) }) - const argDescriptor = '(' + argTypes.join('') + ')' - - let foundMethod = false - const methodInfos = symbolInfos[symbolInfos.length - 1] as MethodInfos - for (let i = 0; i < methodInfos.length; i++) { - const methodInfo = methodInfos[i] - if (methodInfo.typeDescriptor.includes(argDescriptor)) { - const method = cg.constantPoolManager.indexMethodrefInfo( - methodInfo.parentClassName, - methodInfo.name, - methodInfo.typeDescriptor - ) - if ( - n.identifier.startsWith('this.') && - !(methodInfo.accessFlags & FIELD_FLAGS.ACC_STATIC) - ) { - // load "this" - cg.code.push(OPCODE.ALOAD, 0) - } - cg.code.push( - methodInfo.accessFlags & METHOD_FLAGS.ACC_STATIC - ? OPCODE.INVOKESTATIC - : OPCODE.INVOKEVIRTUAL, - 0, - method - ) - resultType = methodInfo.typeDescriptor.slice(argDescriptor.length) - foundMethod = true - break - } - } - if (!foundMethod) { - throw new InvalidMethodCallError( - `No method matching signature ${n.identifier}${argDescriptor} found.` - ) + const method = cg.constantPoolManager.indexMethodrefInfo( + selectedMethod.parentClassName, + selectedMethod.name, + selectedMethod.typeDescriptor + ); + if ( + n.identifier.startsWith('this.') && + !(selectedMethod.accessFlags & FIELD_FLAGS.ACC_STATIC) + ) { + cg.code.push(OPCODE.ALOAD, 0); } - return { stackSize: maxStack, resultType: resultType } + cg.code.push( + selectedMethod.accessFlags & METHOD_FLAGS.ACC_STATIC + ? OPCODE.INVOKESTATIC + : OPCODE.INVOKEVIRTUAL, + 0, + method + ); + resultType = selectedMethod.typeDescriptor.slice(selectedMethod.typeDescriptor.indexOf(')') + 1); + + return { stackSize: maxStack, resultType: resultType }; }, Assignment: (node: Node, cg: CodeGenerator) => { From c12fbb8ac3c739d704e4c6a391ee8fd8dc495e18 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 4 Mar 2025 11:18:53 +0800 Subject: [PATCH 18/29] Handle ambiguity during overloading in the compiler --- src/compiler/code-generator.ts | 38 ++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index e290532d..59ffcc5e 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -345,6 +345,12 @@ function getExpressionType(node: Node, cg: CodeGenerator): string { return resultType; } +function isSubtype(fromType: string, toType: string): boolean { + return (fromType === toType) || + (typeConversionsImplicit[`${fromType}->${toType}`] !== undefined) || + (areClassTypesCompatible(fromType, toType)) +} + const isNullLiteral = (node: Node) => { return node.kind === 'Literal' && node.literalType.kind === 'NullLiteral' } @@ -809,9 +815,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi continue } - if ((params[i] !== argumentDescriptors[i]) && - (typeConversionsImplicit[`${argumentDescriptors[i]}->${params[i]}`] === undefined) && - (!areClassTypesCompatible(argumentDescriptors[i], params[i]))) { + if (!isSubtype(argumentDescriptors[i], params[i])) { match = false; break; } @@ -828,7 +832,33 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi ); } - const selectedMethod = methodMatches[0]; + console.log(methodMatches) + + let selectedMethod = methodMatches[0] + + if (methodMatches.length > 1) { + for (let i = 1; i < methodMatches.length; i++) { + const paramDescriptor = methodMatches[i].typeDescriptor.slice(1, methodMatches[i].typeDescriptor.indexOf(')')); + const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || []; + + const paramDescriptorCurrent = selectedMethod.typeDescriptor.slice(1, selectedMethod.typeDescriptor.indexOf(')')); + const paramsCurrent = paramDescriptorCurrent.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || []; + + if (params.map((param, index) => isSubtype(param, paramsCurrent[index])).reduce((a, b) => a && b)) { + selectedMethod = methodMatches[i]; + console.debug('This') + } else if (paramsCurrent.map((param, index) => isSubtype(param, params[index])).reduce((a, b) => a && b)) { + // do nothing + console.debug('Other') + } else { + console.debug('Ambiguous') + throw new InvalidMethodCallError( + `Ambiguous method call: ${n.identifier}(${argumentDescriptors.join(',')})` + ) + } + } + } + const fullDescriptor = selectedMethod.typeDescriptor // Full descriptor, e.g., "(Ljava/lang/String;C)V" const paramDescriptor = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) // Extract "Ljava/lang/String;C" const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] From 28a59e1e2351dd92df9fc7b53564bb9305657abb Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Tue, 4 Mar 2025 13:45:42 +0800 Subject: [PATCH 19/29] Modify type checker for overloading support --- src/types/types/methods.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types/types/methods.ts b/src/types/types/methods.ts index 7c8c9365..d7df067a 100644 --- a/src/types/types/methods.ts +++ b/src/types/types/methods.ts @@ -55,7 +55,7 @@ export class Parameter { return ( object instanceof Parameter && this._name === object._name && - (this._type.canBeAssigned(object._type) || object._type.canBeAssigned(this._type)) && + this._type === object._type && this._isVarargs === object._isVarargs ) } From f0c86ed036c5e488b63fc2cfdf7067b18529e547 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Thu, 3 Apr 2025 11:34:14 +0800 Subject: [PATCH 20/29] Change symbol-table.ts and compiler.ts logic for overriding. Add test cases for method overriding. --- .../__tests__/__utils__/test-utils.ts | 56 +- src/compiler/__tests__/index.ts | 2 + src/compiler/code-generator.ts | 670 +++++++++--------- src/compiler/compiler.ts | 48 +- src/compiler/error.ts | 20 +- src/compiler/import/lib-info.ts | 75 +- src/compiler/index.ts | 4 +- src/compiler/symbol-table.ts | 84 ++- 8 files changed, 535 insertions(+), 424 deletions(-) diff --git a/src/compiler/__tests__/__utils__/test-utils.ts b/src/compiler/__tests__/__utils__/test-utils.ts index 36d11f57..39b6f156 100644 --- a/src/compiler/__tests__/__utils__/test-utils.ts +++ b/src/compiler/__tests__/__utils__/test-utils.ts @@ -1,45 +1,47 @@ -import { inspect } from "util"; -import { compile } from "../../index"; -import { BinaryWriter } from "../../binary-writer"; -import { AST } from "../../../ast/types/packages-and-modules"; -import { javaPegGrammar } from "../../grammar" +import { inspect } from 'util' +import { compile } from '../../index' +import { BinaryWriter } from '../../binary-writer' +import { AST } from '../../../ast/types/packages-and-modules' +import { javaPegGrammar } from '../../grammar' import { peggyFunctions } from '../../peggy-functions' -import { execSync } from "child_process"; +import { execSync } from 'child_process' -import * as peggy from "peggy"; -import * as fs from "fs"; +import * as peggy from 'peggy' +import * as fs from 'fs' export type testCase = { - comment: string, - program: string, - expectedLines: string[], + comment: string + program: string + expectedLines: string[] } -const debug = false; -const pathToTestDir = "./src/compiler/__tests__/"; +const debug = false +const pathToTestDir = './src/compiler/__tests__/' const parser = peggy.generate(peggyFunctions + javaPegGrammar, { - allowedStartRules: ["CompilationUnit"], -}); -const binaryWriter = new BinaryWriter(); + allowedStartRules: ['CompilationUnit'] +}) +const binaryWriter = new BinaryWriter() export function runTest(program: string, expectedLines: string[]) { - const ast = parser.parse(program); - expect(ast).not.toBeNull(); + const ast = parser.parse(program) + expect(ast).not.toBeNull() if (debug) { - console.log(inspect(ast, false, null, true)); + console.log(inspect(ast, false, null, true)) } - const classFile = compile(ast as AST); - binaryWriter.writeBinary(classFile, pathToTestDir); + const classFiles = compile(ast as AST) + for (let classFile of classFiles) { + binaryWriter.writeBinary(classFile, pathToTestDir) + } - const prevDir = process.cwd(); - process.chdir(pathToTestDir); - execSync("java -noverify Main > output.log 2> err.log"); + const prevDir = process.cwd() + process.chdir(pathToTestDir) + execSync('java -noverify Main > output.log 2> err.log') // ignore difference between \r\n and \n - const actualLines = fs.readFileSync("./output.log", 'utf-8').split(/\r?\n/).slice(0, -1); - process.chdir(prevDir); + const actualLines = fs.readFileSync('./output.log', 'utf-8').split(/\r?\n/).slice(0, -1) + process.chdir(prevDir) - expect(actualLines).toStrictEqual(expectedLines); + expect(actualLines).toStrictEqual(expectedLines) } diff --git a/src/compiler/__tests__/index.ts b/src/compiler/__tests__/index.ts index ba44159d..5ea3b2b8 100644 --- a/src/compiler/__tests__/index.ts +++ b/src/compiler/__tests__/index.ts @@ -13,8 +13,10 @@ import { assignmentExpressionTest } from './tests/assignmentExpression.test' import { castExpressionTest } from './tests/castExpression.test' import { switchTest } from './tests/switch.test' import { methodOverloadingTest } from './tests/methodOverloading.test' +import { methodOverridingTest } from './tests/methodOverriding.test' describe('compiler tests', () => { + methodOverridingTest() methodOverloadingTest() switchTest() castExpressionTest() diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 59ffcc5e..24b9d76c 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -24,11 +24,20 @@ import { ClassInstanceCreationExpression, ExpressionStatement, TernaryExpression, - LeftHandSide, CastExpression, SwitchStatement, SwitchCase, CaseLabel + LeftHandSide, + CastExpression, + SwitchStatement, + SwitchCase, + CaseLabel } from '../ast/types/blocks-and-statements' import { MethodDeclaration, UnannType } from '../ast/types/classes' import { ConstantPoolManager } from './constant-pool-manager' -import { ConstructNotSupportedError, InvalidMethodCallError } from './error' +import { + AmbiguousMethodCallError, + ConstructNotSupportedError, + MethodNotFoundError, + NoMethodMatchingSignatureError +} from './error' import { FieldInfo, MethodInfos, SymbolInfo, SymbolTable, VariableInfo } from './symbol-table' type Label = { @@ -180,7 +189,7 @@ const typeConversions: { [key: string]: OPCODE } = { 'J->I': OPCODE.L2I, 'J->F': OPCODE.L2F, 'J->D': OPCODE.L2D -}; +} const typeConversionsImplicit: { [key: string]: OPCODE } = { 'I->F': OPCODE.I2F, @@ -205,12 +214,12 @@ function areClassTypesCompatible(fromType: string, toType: string): boolean { function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator): number { if (fromType === toType || toType.replace(/^L|;$/g, '') === 'java/lang/String') { - return 0; + return 0 } if (fromType.startsWith('L') || toType.startsWith('L')) { if (areClassTypesCompatible(fromType, toType) || fromType === '') { - return 0; + return 0 } throw new Error(`Unsupported class type conversion: ${fromType} -> ${toType}`) } @@ -219,11 +228,11 @@ function handleImplicitTypeConversion(fromType: string, toType: string, cg: Code if (conversionKey in typeConversionsImplicit) { cg.code.push(typeConversionsImplicit[conversionKey]) if (!(fromType in ['J', 'D']) && toType in ['J', 'D']) { - return 1; + return 1 } else if (!(toType in ['J', 'D']) && fromType in ['J', 'D']) { - return -1; + return -1 } else { - return 0; + return 0 } } else { throw new Error(`Unsupported implicit type conversion: ${conversionKey}`) @@ -232,18 +241,18 @@ function handleImplicitTypeConversion(fromType: string, toType: string, cg: Code function handleExplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { if (fromType === toType) { - return; + return } - const conversionKey = `${fromType}->${toType}`; + const conversionKey = `${fromType}->${toType}` if (conversionKey in typeConversions) { - cg.code.push(typeConversions[conversionKey]); + cg.code.push(typeConversions[conversionKey]) } else { - throw new Error(`Unsupported explicit type conversion: ${conversionKey}`); + throw new Error(`Unsupported explicit type conversion: ${conversionKey}`) } } function generateStringConversion(valueType: string, cg: CodeGenerator): void { - const stringClass = 'java/lang/String'; + const stringClass = 'java/lang/String' // Map primitive types to `String.valueOf()` method descriptors const valueOfDescriptors: { [key: string]: string } = { @@ -254,29 +263,25 @@ function generateStringConversion(valueType: string, cg: CodeGenerator): void { Z: '(Z)Ljava/lang/String;', // boolean B: '(B)Ljava/lang/String;', // byte S: '(S)Ljava/lang/String;', // short - C: '(C)Ljava/lang/String;' // char - }; + C: '(C)Ljava/lang/String;' // char + } - const descriptor = valueOfDescriptors[valueType]; + const descriptor = valueOfDescriptors[valueType] if (!descriptor) { - throw new Error(`Unsupported primitive type for String conversion: ${valueType}`); + throw new Error(`Unsupported primitive type for String conversion: ${valueType}`) } - const methodIndex = cg.constantPoolManager.indexMethodrefInfo( - stringClass, - 'valueOf', - descriptor - ); + const methodIndex = cg.constantPoolManager.indexMethodrefInfo(stringClass, 'valueOf', descriptor) - cg.code.push(OPCODE.INVOKESTATIC, 0, methodIndex); + cg.code.push(OPCODE.INVOKESTATIC, 0, methodIndex) } function hashCode(str: string): number { - let hash = 0; + let hash = 0 for (let i = 0; i < str.length; i++) { - hash = ((hash * 31) + str.charCodeAt(i)); // Simulate Java's overflow behavior + hash = hash * 31 + str.charCodeAt(i) // Simulate Java's overflow behavior } - return hash; + return hash } // function generateBooleanConversion(type: string, cg: CodeGenerator): number { @@ -337,18 +342,20 @@ function hashCode(str: string): number { function getExpressionType(node: Node, cg: CodeGenerator): string { if (!(node.kind in codeGenerators)) { - throw new ConstructNotSupportedError(node.kind); + throw new ConstructNotSupportedError(node.kind) } - const originalCode = [...cg.code]; // Preserve the original code state - const resultType = codeGenerators[node.kind](node, cg).resultType; - cg.code = originalCode; // Restore the original code state - return resultType; + const originalCode = [...cg.code] // Preserve the original code state + const resultType = codeGenerators[node.kind](node, cg).resultType + cg.code = originalCode // Restore the original code state + return resultType } function isSubtype(fromType: string, toType: string): boolean { - return (fromType === toType) || - (typeConversionsImplicit[`${fromType}->${toType}`] !== undefined) || - (areClassTypesCompatible(fromType, toType)) + return ( + fromType === toType || + typeConversionsImplicit[`${fromType}->${toType}`] !== undefined || + areClassTypesCompatible(fromType, toType) + ) } const isNullLiteral = (node: Node) => { @@ -434,7 +441,11 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.code.push(OPCODE.ASTORE, curIdx) } else { const { stackSize: initializerStackSize, resultType: initializerType } = compile(vi, cg) - const stackSizeChange = handleImplicitTypeConversion(initializerType, variableInfo.typeDescriptor, cg) + const stackSizeChange = handleImplicitTypeConversion( + initializerType, + variableInfo.typeDescriptor, + cg + ) maxStack = Math.max(maxStack, initializerStackSize + stackSizeChange) cg.code.push( variableInfo.typeDescriptor in normalStoreOp @@ -462,14 +473,14 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi BreakStatement: (node: Node, cg: CodeGenerator) => { if (cg.loopLabels.length > 0) { // If inside a loop, break jumps to the end of the loop - cg.addBranchInstr(OPCODE.GOTO, cg.loopLabels[cg.loopLabels.length - 1][1]); + cg.addBranchInstr(OPCODE.GOTO, cg.loopLabels[cg.loopLabels.length - 1][1]) } else if (cg.switchLabels.length > 0) { // If inside a switch, break jumps to the switch's end label - cg.addBranchInstr(OPCODE.GOTO, cg.switchLabels[cg.switchLabels.length - 1]); + cg.addBranchInstr(OPCODE.GOTO, cg.switchLabels[cg.switchLabels.length - 1]) } else { - throw new Error("Break statement not inside a loop or switch statement"); + throw new Error('Break statement not inside a loop or switch statement') } - return { stackSize: 0, resultType: EMPTY_TYPE }; + return { stackSize: 0, resultType: EMPTY_TYPE } }, ContinueStatement: (node: Node, cg: CodeGenerator) => { @@ -767,136 +778,94 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const n = node as MethodInvocation let maxStack = 1 let resultType = EMPTY_TYPE + const candidateMethods: MethodInfos = [] - const symbolInfos = cg.symbolTable.queryMethod(n.identifier) - if (!symbolInfos || symbolInfos.length === 0) { - throw new Error(`Method not found: ${n.identifier}`) + // TODO: Write logic to get candidateMethods + // --- Handle super. calls --- + if (n.identifier.startsWith('super.')) { } - - for (let i = 0; i < symbolInfos.length - 1; i++) { - if (i === 0) { - const varInfo = symbolInfos[i] as VariableInfo - if (varInfo.index !== undefined) { - cg.code.push(OPCODE.ALOAD, varInfo.index) - continue - } - } - const fieldInfo = symbolInfos[i] as FieldInfo - const field = cg.constantPoolManager.indexFieldrefInfo( - fieldInfo.parentClassName, - fieldInfo.name, - fieldInfo.typeDescriptor - ) - cg.code.push( - fieldInfo.accessFlags & FIELD_FLAGS.ACC_STATIC ? OPCODE.GETSTATIC : OPCODE.GETFIELD, - 0, - field - ) + // --- Handle qualified calls (e.g. System.out.println or p.show) --- + else if (n.identifier.includes('.')) { } - - const methodInfos = symbolInfos[symbolInfos.length - 1] as MethodInfos - if (!methodInfos || methodInfos.length === 0) { - throw new Error(`No method information found for ${n.identifier}`) + // --- Handle unqualified calls (including this.method()) --- + else { } - const methodMatches: MethodInfos = []; - const argumentDescriptors = n.argumentList.map(arg => getExpressionType(arg, cg)); - - for (const methodInfo of methodInfos) { - const paramDescriptor = methodInfo.typeDescriptor.slice(1, methodInfo.typeDescriptor.indexOf(')')); - const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || []; - - if (params && params.length === argumentDescriptors.length) { - let match = true; - - for (let i = 0; i < params.length; i++) { - if ((argumentDescriptors[i] == 'B' || argumentDescriptors[i] == 'S') - && paramDescriptor[i] == 'I') { - continue - } - - if (!isSubtype(argumentDescriptors[i], params[i])) { - match = false; - break; - } - } - if (match) { - methodMatches.push(methodInfo); + // Filter candidate methods by matching the argument list. + const argDescs = n.argumentList.map(arg => getExpressionType(arg, cg)) + const methodMatches: MethodInfos = [] + for (let i = 0; i < candidateMethods.length; i++) { + const m = candidateMethods[i] + const fullDesc = m.typeDescriptor // e.g., "(Ljava/lang/String;C)V" + const paramPart = fullDesc.slice(1, fullDesc.indexOf(')')) + const params = paramPart.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] + if (params.length !== argDescs.length) continue + let match = true + for (let i = 0; i < params.length; i++) { + const argType = argDescs[i] + // Allow B/S to match int. + if ((argType === 'B' || argType === 'S') && params[i] === 'I') continue + if (!isSubtype(argType, params[i])) { + match = false + break } } + if (match) methodMatches.push(m) } - if (methodMatches.length === 0) { - throw new InvalidMethodCallError( - `No method matching signature ${n.identifier}(${argumentDescriptors.join(',')}) found.` - ); + throw new NoMethodMatchingSignatureError(n.identifier + argDescs.join(',')) } - console.log(methodMatches) - + // Overload resolution (simple: choose first, or refine if needed) let selectedMethod = methodMatches[0] - if (methodMatches.length > 1) { for (let i = 1; i < methodMatches.length; i++) { - const paramDescriptor = methodMatches[i].typeDescriptor.slice(1, methodMatches[i].typeDescriptor.indexOf(')')); - const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || []; - - const paramDescriptorCurrent = selectedMethod.typeDescriptor.slice(1, selectedMethod.typeDescriptor.indexOf(')')); - const paramsCurrent = paramDescriptorCurrent.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || []; - - if (params.map((param, index) => isSubtype(param, paramsCurrent[index])).reduce((a, b) => a && b)) { - selectedMethod = methodMatches[i]; - console.debug('This') - } else if (paramsCurrent.map((param, index) => isSubtype(param, params[index])).reduce((a, b) => a && b)) { - // do nothing - console.debug('Other') - } else { - console.debug('Ambiguous') - throw new InvalidMethodCallError( - `Ambiguous method call: ${n.identifier}(${argumentDescriptors.join(',')})` - ) + const currParams = + selectedMethod.typeDescriptor + .slice(1, selectedMethod.typeDescriptor.indexOf(')')) + .match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] + const candParams = + methodMatches[i].typeDescriptor + .slice(1, methodMatches[i].typeDescriptor.indexOf(')')) + .match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] + if ( + candParams.map((p, idx) => isSubtype(p, currParams[idx])).reduce((a, b) => a && b, true) + ) { + selectedMethod = methodMatches[i] + } else if ( + !currParams.map((p, idx) => isSubtype(p, candParams[idx])).reduce((a, b) => a && b, true) + ) { + throw new AmbiguousMethodCallError(n.identifier + argDescs.join(',')) } } } - const fullDescriptor = selectedMethod.typeDescriptor // Full descriptor, e.g., "(Ljava/lang/String;C)V" - const paramDescriptor = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) // Extract "Ljava/lang/String;C" - const params = paramDescriptor.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] - - n.argumentList.forEach((x, i) => { - const argCompileResult = compile(x, cg) - - let normalizedType = argCompileResult.resultType; - if (normalizedType === 'B' || normalizedType === 'S') { - normalizedType = 'I' - } - - const expectedType = params?.[i] // Expected parameter type - const stackSizeChange = handleImplicitTypeConversion(normalizedType, expectedType ?? '', cg) - maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize + stackSizeChange) + // Compile each argument. + const fullDescriptor = selectedMethod.typeDescriptor + const paramPart = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) + const params = paramPart.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] + n.argumentList.forEach((arg, i) => { + const argRes = compile(arg, cg) + let argType = argRes.resultType + if (argType === 'B' || argType === 'S') argType = 'I' + const conv = handleImplicitTypeConversion(argType, params[i] || '', cg) + maxStack = Math.max(maxStack, i + 1 + argRes.stackSize + conv) }) - const method = cg.constantPoolManager.indexMethodrefInfo( - selectedMethod.parentClassName, + // Emit the method call. + const methodRef = cg.constantPoolManager.indexMethodrefInfo( + selectedMethod.className, selectedMethod.name, selectedMethod.typeDescriptor - ); - if ( - n.identifier.startsWith('this.') && - !(selectedMethod.accessFlags & FIELD_FLAGS.ACC_STATIC) - ) { - cg.code.push(OPCODE.ALOAD, 0); + ) + if (n.identifier.startsWith('super.')) { + cg.code.push(OPCODE.INVOKESPECIAL, 0, methodRef) + } else { + const isStatic = (selectedMethod.accessFlags & METHOD_FLAGS.ACC_STATIC) !== 0 + cg.code.push(isStatic ? OPCODE.INVOKESTATIC : OPCODE.INVOKEVIRTUAL, 0, methodRef) } - cg.code.push( - selectedMethod.accessFlags & METHOD_FLAGS.ACC_STATIC - ? OPCODE.INVOKESTATIC - : OPCODE.INVOKEVIRTUAL, - 0, - method - ); - resultType = selectedMethod.typeDescriptor.slice(selectedMethod.typeDescriptor.indexOf(')') + 1); - - return { stackSize: maxStack, resultType: resultType }; + resultType = selectedMethod.typeDescriptor.slice(selectedMethod.typeDescriptor.indexOf(')') + 1) + return { stackSize: maxStack, resultType: resultType } }, Assignment: (node: Node, cg: CodeGenerator) => { @@ -1009,19 +978,17 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } const { stackSize: size1, resultType: leftType } = compile(left, cg) - const insertConversionIndex = cg.code.length; - cg.code.push(OPCODE.NOP); + const insertConversionIndex = cg.code.length + cg.code.push(OPCODE.NOP) const { stackSize: size2, resultType: rightType } = compile(right, cg) - if (op === '+' && - (leftType === 'Ljava/lang/String;' - || rightType === 'Ljava/lang/String;')) { + if (op === '+' && (leftType === 'Ljava/lang/String;' || rightType === 'Ljava/lang/String;')) { if (leftType !== 'Ljava/lang/String;') { - generateStringConversion(leftType, cg); + generateStringConversion(leftType, cg) } if (rightType !== 'Ljava/lang/String;') { - generateStringConversion(rightType, cg); + generateStringConversion(rightType, cg) } // Invoke `String.concat` for concatenation @@ -1029,16 +996,16 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi 'java/lang/String', 'concat', '(Ljava/lang/String;)Ljava/lang/String;' - ); - cg.code.push(OPCODE.INVOKEVIRTUAL, 0, concatMethodIndex); + ) + cg.code.push(OPCODE.INVOKEVIRTUAL, 0, concatMethodIndex) return { stackSize: Math.max(size1 + 1, size2 + 1), // Max stack size plus one for the concatenation resultType: 'Ljava/lang/String;' - }; + } } - let finalType = leftType; + let finalType = leftType if (leftType !== rightType) { const conversionKeyLeft = `${leftType}->${rightType}` @@ -1047,63 +1014,75 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi if (['D', 'F'].includes(leftType) || ['D', 'F'].includes(rightType)) { // Promote both to double if one is double, or to float otherwise if (leftType !== 'D' && rightType === 'D') { - cg.code.fill(typeConversionsImplicit[conversionKeyLeft], - insertConversionIndex, insertConversionIndex + 1) - finalType = 'D'; + cg.code.fill( + typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, + insertConversionIndex + 1 + ) + finalType = 'D' } else if (leftType === 'D' && rightType !== 'D') { cg.code.push(typeConversionsImplicit[conversionKeyRight]) - finalType = 'D'; + finalType = 'D' } else if (leftType !== 'F' && rightType === 'F') { // handleImplicitTypeConversion(leftType, 'F', cg); - cg.code.fill(typeConversionsImplicit[conversionKeyLeft], - insertConversionIndex, insertConversionIndex + 1) - finalType = 'F'; + cg.code.fill( + typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, + insertConversionIndex + 1 + ) + finalType = 'F' } else if (leftType === 'F' && rightType !== 'F') { cg.code.push(typeConversionsImplicit[conversionKeyRight]) - finalType = 'F'; + finalType = 'F' } } else if (['J'].includes(leftType) || ['J'].includes(rightType)) { // Promote both to long if one is long if (leftType !== 'J' && rightType === 'J') { - cg.code.fill(typeConversionsImplicit[conversionKeyLeft], - insertConversionIndex, insertConversionIndex + 1) + cg.code.fill( + typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, + insertConversionIndex + 1 + ) } else if (leftType === 'J' && rightType !== 'J') { cg.code.push(typeConversionsImplicit[conversionKeyRight]) } - finalType = 'J'; + finalType = 'J' } else { // Promote both to int as the common type for smaller types like byte, short, char if (leftType !== 'I') { - cg.code.fill(typeConversionsImplicit[conversionKeyLeft], - insertConversionIndex, insertConversionIndex + 1) + cg.code.fill( + typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, + insertConversionIndex + 1 + ) } if (rightType !== 'I') { cg.code.push(typeConversionsImplicit[conversionKeyRight]) } - finalType = 'I'; + finalType = 'I' } } // Perform the operation switch (finalType) { case 'B': - cg.code.push(intBinaryOp[op], OPCODE.I2B); - break; + cg.code.push(intBinaryOp[op], OPCODE.I2B) + break case 'D': - cg.code.push(doubleBinaryOp[op]); - break; + cg.code.push(doubleBinaryOp[op]) + break case 'F': - cg.code.push(floatBinaryOp[op]); - break; + cg.code.push(floatBinaryOp[op]) + break case 'I': - cg.code.push(intBinaryOp[op]); - break; + cg.code.push(intBinaryOp[op]) + break case 'J': - cg.code.push(longBinaryOp[op]); - break; + cg.code.push(longBinaryOp[op]) + break case 'S': - cg.code.push(intBinaryOp[op], OPCODE.I2S); - break; + cg.code.push(intBinaryOp[op], OPCODE.I2S) + break } return { @@ -1148,13 +1127,13 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi I: OPCODE.INEG, // Integer negation J: OPCODE.LNEG, // Long negation F: OPCODE.FNEG, // Float negation - D: OPCODE.DNEG, // Double negation - }; + D: OPCODE.DNEG // Double negation + } if (compileResult.resultType in negationOpcodes) { - cg.code.push(negationOpcodes[compileResult.resultType]); + cg.code.push(negationOpcodes[compileResult.resultType]) } else { - throw new Error(`Unary '-' not supported for type: ${compileResult.resultType}`); + throw new Error(`Unary '-' not supported for type: ${compileResult.resultType}`) } } else if (op === '~') { cg.code.push(OPCODE.ICONST_M1, OPCODE.IXOR) @@ -1307,18 +1286,18 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }, CastExpression: (node: Node, cg: CodeGenerator) => { - const { expression, type } = node as CastExpression; // CastExpression node structure - const { stackSize, resultType } = compile(expression, cg); + const { expression, type } = node as CastExpression // CastExpression node structure + const { stackSize, resultType } = compile(expression, cg) if ((type == 'byte' || type == 'short') && resultType != 'I') { - handleExplicitTypeConversion(resultType, 'I', cg); - handleExplicitTypeConversion('I', cg.symbolTable.generateFieldDescriptor(type), cg); + handleExplicitTypeConversion(resultType, 'I', cg) + handleExplicitTypeConversion('I', cg.symbolTable.generateFieldDescriptor(type), cg) } else if (resultType == 'C') { if (type == 'int') { return { stackSize, - resultType: cg.symbolTable.generateFieldDescriptor('int'), - }; + resultType: cg.symbolTable.generateFieldDescriptor('int') + } } else { throw new Error(`Unsupported class type conversion: ${'C'} -> ${cg.symbolTable.generateFieldDescriptor(type)}`) @@ -1331,236 +1310,264 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi ${resultType} -> ${cg.symbolTable.generateFieldDescriptor(type)}`) } } else { - handleExplicitTypeConversion(resultType, cg.symbolTable.generateFieldDescriptor(type), cg); + handleExplicitTypeConversion(resultType, cg.symbolTable.generateFieldDescriptor(type), cg) } return { stackSize, - resultType: cg.symbolTable.generateFieldDescriptor(type), + resultType: cg.symbolTable.generateFieldDescriptor(type) } }, SwitchStatement: (node: Node, cg: CodeGenerator) => { - const { expression, cases } = node as SwitchStatement; + const { expression, cases } = node as SwitchStatement // Compile the switch expression - const { stackSize: exprStackSize, resultType } = compile(expression, cg); - let maxStack = exprStackSize; + const { stackSize: exprStackSize, resultType } = compile(expression, cg) + let maxStack = exprStackSize - const caseLabels: Label[] = cases.map(() => cg.generateNewLabel()); - const defaultLabel = cg.generateNewLabel(); - const endLabel = cg.generateNewLabel(); + const caseLabels: Label[] = cases.map(() => cg.generateNewLabel()) + const defaultLabel = cg.generateNewLabel() + const endLabel = cg.generateNewLabel() // Track the switch statement's end label - cg.switchLabels.push(endLabel); + cg.switchLabels.push(endLabel) - if (["I", "B", "S", "C"].includes(resultType)) { - const caseValues: number[] = []; - const caseLabelMap: Map = new Map(); - let hasDefault = false; - const positionOffset = cg.code.length; + if (['I', 'B', 'S', 'C'].includes(resultType)) { + const caseValues: number[] = [] + const caseLabelMap: Map = new Map() + let hasDefault = false + const positionOffset = cg.code.length cases.forEach((caseGroup, index) => { - caseGroup.labels.forEach((label) => { - if (label.kind === "CaseLabel") { - const value = parseInt((label.expression as Literal).literalType.value); - caseValues.push(value); - caseLabelMap.set(value, caseLabels[index]); - } else if (label.kind === "DefaultLabel") { - caseLabels[index] = defaultLabel; - hasDefault = true; + caseGroup.labels.forEach(label => { + if (label.kind === 'CaseLabel') { + const value = parseInt((label.expression as Literal).literalType.value) + caseValues.push(value) + caseLabelMap.set(value, caseLabels[index]) + } else if (label.kind === 'DefaultLabel') { + caseLabels[index] = defaultLabel + hasDefault = true } - }); - }); + }) + }) - const [minValue, maxValue] = [Math.min(...caseValues), Math.max(...caseValues)]; - const useTableSwitch = maxValue - minValue < caseValues.length * 2; + const [minValue, maxValue] = [Math.min(...caseValues), Math.max(...caseValues)] + const useTableSwitch = maxValue - minValue < caseValues.length * 2 const caseLabelIndex: number[] = [] - let indexTracker = cg.code.length; + let indexTracker = cg.code.length if (useTableSwitch) { - cg.code.push(OPCODE.TABLESWITCH); + cg.code.push(OPCODE.TABLESWITCH) indexTracker++ // Ensure 4-byte alignment for TABLESWITCH while (cg.code.length % 4 !== 0) { - cg.code.push(0); // Padding bytes (JVM requires alignment) + cg.code.push(0) // Padding bytes (JVM requires alignment) indexTracker++ } // Add default branch (jump to default label) - cg.code.push(0, 0, 0, defaultLabel.offset); - caseLabelIndex.push(indexTracker + 3); + cg.code.push(0, 0, 0, defaultLabel.offset) + caseLabelIndex.push(indexTracker + 3) indexTracker += 4 // Push low and high values (min and max case values) - cg.code.push(minValue >> 24, (minValue >> 16) & 0xff, (minValue >> 8) & 0xff, minValue & 0xff); - cg.code.push(maxValue >> 24, (maxValue >> 16) & 0xff, (maxValue >> 8) & 0xff, maxValue & 0xff); + cg.code.push( + minValue >> 24, + (minValue >> 16) & 0xff, + (minValue >> 8) & 0xff, + minValue & 0xff + ) + cg.code.push( + maxValue >> 24, + (maxValue >> 16) & 0xff, + (maxValue >> 8) & 0xff, + maxValue & 0xff + ) indexTracker += 8 // Generate branch table (map each value to a case label) for (let i = minValue; i <= maxValue; i++) { - const caseIndex = caseValues.indexOf(i); - cg.code.push(0, 0, 0, caseIndex !== -1 ? caseLabels[caseIndex].offset - : defaultLabel.offset); - caseLabelIndex.push(indexTracker + 3); - indexTracker += 4; + const caseIndex = caseValues.indexOf(i) + cg.code.push( + 0, + 0, + 0, + caseIndex !== -1 ? caseLabels[caseIndex].offset : defaultLabel.offset + ) + caseLabelIndex.push(indexTracker + 3) + indexTracker += 4 } } else { - cg.code.push(OPCODE.LOOKUPSWITCH); - indexTracker++; + cg.code.push(OPCODE.LOOKUPSWITCH) + indexTracker++ // Ensure 4-byte alignment for LOOKUPSWITCH while (cg.code.length % 4 !== 0) { - cg.code.push(0); + cg.code.push(0) indexTracker++ } // Add default branch (jump to default label) - cg.code.push(0, 0, 0, defaultLabel.offset); - caseLabelIndex.push(indexTracker + 3); + cg.code.push(0, 0, 0, defaultLabel.offset) + caseLabelIndex.push(indexTracker + 3) indexTracker += 4 // Push the number of case-value pairs - cg.code.push((caseValues.length >> 24) & 0xff, (caseValues.length >> 16) & 0xff, - (caseValues.length >> 8) & 0xff, caseValues.length & 0xff); + cg.code.push( + (caseValues.length >> 24) & 0xff, + (caseValues.length >> 16) & 0xff, + (caseValues.length >> 8) & 0xff, + caseValues.length & 0xff + ) indexTracker += 4 // Generate lookup table (pairs of case values and corresponding labels) caseValues.forEach((value, index) => { - cg.code.push(value >> 24, (value >> 16) & 0xff, (value >> 8) & 0xff, value & 0xff); - cg.code.push(0, 0, 0, caseLabels[index].offset); - caseLabelIndex.push(indexTracker + 7); - indexTracker += 8; - }); + cg.code.push(value >> 24, (value >> 16) & 0xff, (value >> 8) & 0xff, value & 0xff) + cg.code.push(0, 0, 0, caseLabels[index].offset) + caseLabelIndex.push(indexTracker + 7) + indexTracker += 8 + }) } // **Process case bodies with proper fallthrough handling** - let previousCase: SwitchCase | null = null; + let previousCase: SwitchCase | null = null - const nonDefaultCases = cases.filter((caseGroup) => - caseGroup.labels.some((label) => label.kind === "CaseLabel")) + const nonDefaultCases = cases.filter(caseGroup => + caseGroup.labels.some(label => label.kind === 'CaseLabel') + ) nonDefaultCases.forEach((caseGroup, index) => { - caseLabels[index].offset = cg.code.length; + caseLabels[index].offset = cg.code.length // Ensure statements array is always defined - caseGroup.statements = caseGroup.statements || []; + caseGroup.statements = caseGroup.statements || [] // If previous case had no statements, merge labels (fallthrough) if (previousCase && (previousCase.statements?.length ?? 0) === 0) { - previousCase.labels.push(...caseGroup.labels); + previousCase.labels.push(...caseGroup.labels) } // Compile case statements - caseGroup.statements.forEach((statement) => { - const { stackSize } = compile(statement, cg); - maxStack = Math.max(maxStack, stackSize); - }); + caseGroup.statements.forEach(statement => { + const { stackSize } = compile(statement, cg) + maxStack = Math.max(maxStack, stackSize) + }) - previousCase = caseGroup; - }); + previousCase = caseGroup + }) // **Process default case** - defaultLabel.offset = cg.code.length; + defaultLabel.offset = cg.code.length if (hasDefault) { - const defaultCase = cases.find((caseGroup) => - caseGroup.labels.some((label) => label.kind === "DefaultLabel") - ); + const defaultCase = cases.find(caseGroup => + caseGroup.labels.some(label => label.kind === 'DefaultLabel') + ) if (defaultCase) { - defaultCase.statements = defaultCase.statements || []; - defaultCase.statements.forEach((statement) => { - const { stackSize } = compile(statement, cg); - maxStack = Math.max(maxStack, stackSize); - }); + defaultCase.statements = defaultCase.statements || [] + defaultCase.statements.forEach(statement => { + const { stackSize } = compile(statement, cg) + maxStack = Math.max(maxStack, stackSize) + }) } } - cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset; + cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset for (let i = 1; i < caseLabelIndex.length; i++) { cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset } - endLabel.offset = cg.code.length; - - } else if (resultType === "Ljava/lang/String;") { + endLabel.offset = cg.code.length + } else if (resultType === 'Ljava/lang/String;') { // **String Switch Handling** - const hashCaseMap: Map = new Map(); + const hashCaseMap: Map = new Map() // Compute and store hashCode() cg.code.push( OPCODE.INVOKEVIRTUAL, 0, - cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "hashCode", "()I") - ); + cg.constantPoolManager.indexMethodrefInfo('java/lang/String', 'hashCode', '()I') + ) // Create lookup table for hashCodes cases.forEach((caseGroup, index) => { - caseGroup.labels.forEach((label) => { - if (label.kind === "CaseLabel") { - const caseValue = (label.expression as Literal).literalType.value; - const hashCodeValue = hashCode(caseValue.slice(1, caseValue.length - 1)); + caseGroup.labels.forEach(label => { + if (label.kind === 'CaseLabel') { + const caseValue = (label.expression as Literal).literalType.value + const hashCodeValue = hashCode(caseValue.slice(1, caseValue.length - 1)) if (!hashCaseMap.has(hashCodeValue)) { - hashCaseMap.set(hashCodeValue, caseLabels[index]); + hashCaseMap.set(hashCodeValue, caseLabels[index]) } - } else if (label.kind === "DefaultLabel") { - caseLabels[index] = defaultLabel; + } else if (label.kind === 'DefaultLabel') { + caseLabels[index] = defaultLabel } - }); - }); + }) + }) const caseLabelIndex: number[] = [] - let indexTracker = cg.code.length; - const positionOffset = cg.code.length; + let indexTracker = cg.code.length + const positionOffset = cg.code.length // **LOOKUPSWITCH Implementation** - cg.code.push(OPCODE.LOOKUPSWITCH); + cg.code.push(OPCODE.LOOKUPSWITCH) indexTracker++ // Ensure 4-byte alignment while (cg.code.length % 4 !== 0) { - cg.code.push(0); + cg.code.push(0) indexTracker++ } // Default jump target - cg.code.push(0, 0, 0, defaultLabel.offset); - caseLabelIndex.push(indexTracker + 3); - indexTracker += 4; - + cg.code.push(0, 0, 0, defaultLabel.offset) + caseLabelIndex.push(indexTracker + 3) + indexTracker += 4 // Number of case-value pairs - cg.code.push((hashCaseMap.size >> 24) & 0xff, (hashCaseMap.size >> 16) & 0xff, - (hashCaseMap.size >> 8) & 0xff, hashCaseMap.size & 0xff); - indexTracker += 4; + cg.code.push( + (hashCaseMap.size >> 24) & 0xff, + (hashCaseMap.size >> 16) & 0xff, + (hashCaseMap.size >> 8) & 0xff, + hashCaseMap.size & 0xff + ) + indexTracker += 4 // Populate LOOKUPSWITCH hashCaseMap.forEach((label, hashCode) => { - cg.code.push(hashCode >> 24, (hashCode >> 16) & 0xff, (hashCode >> 8) & 0xff, hashCode & 0xff); - cg.code.push(0, 0, 0, label.offset); - caseLabelIndex.push(indexTracker + 7); - indexTracker += 8; - }); + cg.code.push( + hashCode >> 24, + (hashCode >> 16) & 0xff, + (hashCode >> 8) & 0xff, + hashCode & 0xff + ) + cg.code.push(0, 0, 0, label.offset) + caseLabelIndex.push(indexTracker + 7) + indexTracker += 8 + }) // **Case Handling** - let previousCase: SwitchCase | null = null; + let previousCase: SwitchCase | null = null - cases.filter((caseGroup) => - caseGroup.labels.some((label) => label.kind === "CaseLabel")) + cases + .filter(caseGroup => caseGroup.labels.some(label => label.kind === 'CaseLabel')) .forEach((caseGroup, index) => { - caseLabels[index].offset = cg.code.length; + caseLabels[index].offset = cg.code.length // Ensure statements exist - caseGroup.statements = caseGroup.statements || []; + caseGroup.statements = caseGroup.statements || [] // Handle fallthrough if (previousCase && (previousCase.statements?.length ?? 0) === 0) { - previousCase.labels.push(...caseGroup.labels); + previousCase.labels.push(...caseGroup.labels) } // **String Comparison for Collisions** - const caseValue = caseGroup.labels.find((label): label is CaseLabel => label.kind === "CaseLabel"); + const caseValue = caseGroup.labels.find( + (label): label is CaseLabel => label.kind === 'CaseLabel' + ) if (caseValue) { // TODO: check for actual String equality instead of just rely on hashCode equality // (see the commented out code below) @@ -1575,49 +1582,51 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi // cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "equals", "(Ljava/lang/Object;)Z") // ); // - const caseEndLabel = cg.generateNewLabel(); + const caseEndLabel = cg.generateNewLabel() // cg.addBranchInstr(OPCODE.IFEQ, caseEndLabel); // Compile case statements - caseGroup.statements.forEach((statement) => { - const { stackSize } = compile(statement, cg); - maxStack = Math.max(maxStack, stackSize); - }); + caseGroup.statements.forEach(statement => { + const { stackSize } = compile(statement, cg) + maxStack = Math.max(maxStack, stackSize) + }) - caseEndLabel.offset = cg.code.length; + caseEndLabel.offset = cg.code.length } - previousCase = caseGroup; - }); + previousCase = caseGroup + }) // **Default Case Handling** - defaultLabel.offset = cg.code.length; - const defaultCase = cases.find((caseGroup) => - caseGroup.labels.some((label) => label.kind === "DefaultLabel")); + defaultLabel.offset = cg.code.length + const defaultCase = cases.find(caseGroup => + caseGroup.labels.some(label => label.kind === 'DefaultLabel') + ) if (defaultCase) { - defaultCase.statements = defaultCase.statements || []; - defaultCase.statements.forEach((statement) => { - const { stackSize } = compile(statement, cg); - maxStack = Math.max(maxStack, stackSize); - }); + defaultCase.statements = defaultCase.statements || [] + defaultCase.statements.forEach(statement => { + const { stackSize } = compile(statement, cg) + maxStack = Math.max(maxStack, stackSize) + }) } - cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset; + cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset for (let i = 1; i < caseLabelIndex.length; i++) { cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset } - endLabel.offset = cg.code.length; - + endLabel.offset = cg.code.length } else { - throw new Error(`Switch statements only support byte, short, int, char, or String types. Found: ${resultType}`); + throw new Error( + `Switch statements only support byte, short, int, char, or String types. Found: ${resultType}` + ) } - cg.switchLabels.pop(); + cg.switchLabels.pop() - return { stackSize: maxStack, resultType: EMPTY_TYPE }; + return { stackSize: maxStack, resultType: EMPTY_TYPE } } } @@ -1630,6 +1639,7 @@ class CodeGenerator { loopLabels: Label[][] = [] switchLabels: Label[] = [] code: number[] = [] + currentClass: string constructor(symbolTable: SymbolTable, constantPoolManager: ConstantPoolManager) { this.symbolTable = symbolTable @@ -1660,8 +1670,9 @@ class CodeGenerator { } } - generateCode(methodNode: MethodDeclaration) { + generateCode(currentClass: string, methodNode: MethodDeclaration) { this.symbolTable.extend() + this.currentClass = currentClass if (!methodNode.methodModifier.includes('static')) { this.maxLocals++ } @@ -1684,11 +1695,13 @@ class CodeGenerator { if (methodNode.methodHeader.identifier === '') { this.stackSize = Math.max(this.stackSize, 1) + const parentClass = + this.symbolTable.queryClass(currentClass).parentClassName || 'java/lang/Object' this.code.push( OPCODE.ALOAD_0, OPCODE.INVOKESPECIAL, 0, - this.constantPoolManager.indexMethodrefInfo('java/lang/Object', '', '()V') + this.constantPoolManager.indexMethodrefInfo(parentClass, '', '()V') ) } @@ -1729,8 +1742,9 @@ class CodeGenerator { export function generateCode( symbolTable: SymbolTable, constantPoolManager: ConstantPoolManager, + currentClass: string, methodNode: MethodDeclaration ) { const codeGenerator = new CodeGenerator(symbolTable, constantPoolManager) - return codeGenerator.generateCode(methodNode) + return codeGenerator.generateCode(currentClass, methodNode) } diff --git a/src/compiler/compiler.ts b/src/compiler/compiler.ts index a3566173..84ca6ded 100644 --- a/src/compiler/compiler.ts +++ b/src/compiler/compiler.ts @@ -31,35 +31,55 @@ export class Compiler { private methods: Array private attributes: Array private className: string + private parentClassName: string constructor() { this.setup() } private setup() { + this.symbolTable = new SymbolTable() this.constantPoolManager = new ConstantPoolManager() + } + + private resetClassFileState() { this.interfaces = [] this.fields = [] this.methods = [] this.attributes = [] - this.symbolTable = new SymbolTable() } compile(ast: AST) { this.setup() this.symbolTable.handleImports(ast.importDeclarations) const classFiles: Array = [] - ast.topLevelClassOrInterfaceDeclarations.forEach(x => classFiles.push(this.compileClass(x))) - return classFiles[0] + + ast.topLevelClassOrInterfaceDeclarations.forEach((decl, index) => { + const className = decl.typeIdentifier + const parentClassName = decl.sclass ? decl.sclass : 'java/lang/Object' + const accessFlags = generateClassAccessFlags(decl.classModifier) + this.symbolTable.insertClassInfo( + { name: className, accessFlags: accessFlags, parentClassName: parentClassName }, + true + ) + }) + + ast.topLevelClassOrInterfaceDeclarations.forEach(decl => { + this.resetClassFileState() + const classFile = this.compileClass(decl) + classFiles.push(classFile) + }) + + return classFiles } private compileClass(classNode: ClassDeclaration): ClassFile { - const parentClassName = 'java/lang/Object' this.className = classNode.typeIdentifier + this.parentClassName = classNode.sclass ? classNode.sclass : 'java/lang/Object' const accessFlags = generateClassAccessFlags(classNode.classModifier) - this.symbolTable.insertClassInfo({ name: this.className, accessFlags: accessFlags }) + this.symbolTable.insertClassInfo({ name: this.className, accessFlags: accessFlags }, false) - const superClassIndex = this.constantPoolManager.indexClassInfo(parentClassName) + const superClassIndex = this.constantPoolManager.indexClassInfo(this.parentClassName) const thisClassIndex = this.constantPoolManager.indexClassInfo(this.className) this.constantPoolManager.indexUtf8Info('Code') this.handleClassBody(classNode.classBody) @@ -153,7 +173,7 @@ export class Compiler { this.symbolTable.insertFieldInfo({ name: v.variableDeclaratorId, accessFlags: accessFlags, - parentClassName: this.className, + parentClassName: this.parentClassName, typeName: fullType, typeDescriptor: typeDescriptor }) @@ -170,8 +190,9 @@ export class Compiler { this.symbolTable.insertMethodInfo({ name: methodName, accessFlags: generateMethodAccessFlags(methodNode.methodModifier), - parentClassName: this.className, - typeDescriptor: descriptor + parentClassName: this.parentClassName, + typeDescriptor: descriptor, + className: this.className }) } @@ -183,8 +204,9 @@ export class Compiler { this.symbolTable.insertMethodInfo({ name: '', accessFlags: generateMethodAccessFlags(constructor.constructorModifier), - parentClassName: this.className, - typeDescriptor: descriptor + parentClassName: this.parentClassName, + typeDescriptor: descriptor, + className: this.className }) } @@ -199,7 +221,9 @@ export class Compiler { const descriptorIndex = this.constantPoolManager.indexUtf8Info(descriptor) const attributes: Array = [] - attributes.push(generateCode(this.symbolTable, this.constantPoolManager, methodNode)) + attributes.push( + generateCode(this.symbolTable, this.constantPoolManager, this.className, methodNode) + ) this.methods.push({ accessFlags: generateMethodAccessFlags(methodNode.methodModifier), diff --git a/src/compiler/error.ts b/src/compiler/error.ts index 333f687b..94605f15 100644 --- a/src/compiler/error.ts +++ b/src/compiler/error.ts @@ -18,7 +18,7 @@ export class SymbolRedeclarationError extends CompileError { export class SymbolCannotBeResolvedError extends CompileError { constructor(token: string, fullName: string) { - super('cannot resolve symbol ' + '"' + token + '"' + ' in' + '"' + fullName + '"') + super('cannot resolve symbol ' + '"' + token + '"' + ' in ' + '"' + fullName + '"') } } @@ -33,3 +33,21 @@ export class ConstructNotSupportedError extends CompileError { super('"' + name + '"' + ' is currently not supported by the compiler') } } + +export class MethodNotFoundError extends CompileError { + constructor(methodName: string, className: string) { + super(`Method ${methodName} not found in inheritance chain of ${className}`) + } +} + +export class NoMethodMatchingSignatureError extends CompileError { + constructor(signature: string) { + super(`No method matching signature ${signature}) found.`) + } +} + +export class AmbiguousMethodCallError extends CompileError { + constructor(signature: string) { + super(`Ambiguous method call: ${signature}`) + } +} diff --git a/src/compiler/import/lib-info.ts b/src/compiler/import/lib-info.ts index d334e765..8db187d6 100644 --- a/src/compiler/import/lib-info.ts +++ b/src/compiler/import/lib-info.ts @@ -1,61 +1,60 @@ export const rawLibInfo = { - "packages": [ + packages: [ { - "name": "java.lang", - "classes": [ + name: 'java.lang', + classes: [ { - "name": "public final java.lang.String" + name: 'public final java.lang.String' }, { - "name": "public final java.lang.System", - "fields": [ - "public static final java.io.PrintStream out" - ] + name: 'public final java.lang.Object' + }, + { + name: 'public final java.lang.System', + fields: ['public static final java.io.PrintStream out'] }, { - "name": "public final java.lang.Math", - "methods": [ - "public static int max(int,int)", - "public static int min(int,int)", - "public static double log10(double)" + name: 'public final java.lang.Math', + methods: [ + 'public static int max(int,int)', + 'public static int min(int,int)', + 'public static double log10(double)' ] } ] }, { - "name": "java.io", - "classes": [ + name: 'java.io', + classes: [ { - "name": "public java.io.PrintStream", - "methods": [ - "public void println(java.lang.String)", - "public void println(int)", - "public void println(long)", - "public void println(float)", - "public void println(double)", - "public void println(char)", - "public void println(boolean)", - "public void print(java.lang.String)", - "public void print(int)", - "public void print(long)", - "public void print(float)", - "public void print(double)", - "public void print(char)", - "public void print(boolean)" + name: 'public java.io.PrintStream', + methods: [ + 'public void println(java.lang.String)', + 'public void println(int)', + 'public void println(long)', + 'public void println(float)', + 'public void println(double)', + 'public void println(char)', + 'public void println(boolean)', + 'public void print(java.lang.String)', + 'public void print(int)', + 'public void print(long)', + 'public void print(float)', + 'public void print(double)', + 'public void print(char)', + 'public void print(boolean)' ] } ] }, { - "name": "java.util", - "classes": [ + name: 'java.util', + classes: [ { - "name": "public java.util.Arrays", - "methods": [ - "public static java.lang.String toString(int[])" - ] + name: 'public java.util.Arrays', + methods: ['public static java.lang.String toString(int[])'] } ] } ] -} \ No newline at end of file +} diff --git a/src/compiler/index.ts b/src/compiler/index.ts index 52a93725..3ad4eb1f 100644 --- a/src/compiler/index.ts +++ b/src/compiler/index.ts @@ -5,12 +5,12 @@ import { Compiler } from './compiler' import { javaPegGrammar } from './grammar' import { peggyFunctions } from './peggy-functions' -export const compile = (ast: AST): ClassFile => { +export const compile = (ast: AST): Array => { const compiler = new Compiler() return compiler.compile(ast) } -export const compileFromSource = (javaProgram: string): ClassFile => { +export const compileFromSource = (javaProgram: string): Array => { const parser = peggy.generate(peggyFunctions + javaPegGrammar, { allowedStartRules: ['CompilationUnit'], cache: true diff --git a/src/compiler/symbol-table.ts b/src/compiler/symbol-table.ts index 37c0620e..89dab0a5 100644 --- a/src/compiler/symbol-table.ts +++ b/src/compiler/symbol-table.ts @@ -63,6 +63,7 @@ export interface MethodInfo { accessFlags: number parentClassName: string typeDescriptor: string + className: string } export interface VariableInfo { @@ -99,12 +100,16 @@ export class SymbolTable { private setup() { libraries.forEach(p => { - this.importedPackages.push(p.packageName + '/') + if (this.importedPackages.findIndex(e => e == p.packageName + '/') == -1) + this.importedPackages.push(p.packageName + '/') p.classes.forEach(c => { - this.insertClassInfo({ - name: c.className, - accessFlags: generateClassAccessFlags(c.accessFlags) - }) + this.insertClassInfo( + { + name: c.className, + accessFlags: generateClassAccessFlags(c.accessFlags) + }, + true + ) c.fields.forEach(f => this.insertFieldInfo({ name: f.fieldName, @@ -119,7 +124,8 @@ export class SymbolTable { name: m.methodName, accessFlags: generateMethodAccessFlags(m.accessFlags), parentClassName: c.className, - typeDescriptor: this.generateMethodDescriptor(m.argsTypeName, m.returnTypeName) + typeDescriptor: this.generateMethodDescriptor(m.argsTypeName, m.returnTypeName), + className: c.className }) ) this.returnToRoot() @@ -168,7 +174,7 @@ export class SymbolTable { this.curTable = this.tables[this.curIdx] } - insertClassInfo(info: ClassInfo) { + insertClassInfo(info: ClassInfo, atRoot: boolean) { const key = generateSymbol(info.name, SymbolType.CLASS) if (this.curTable.has(key)) { @@ -181,7 +187,12 @@ export class SymbolTable { } this.curTable.set(key, symbolNode) - this.tables[++this.curIdx] = symbolNode.children + // this logic will need to be modified for inner classes + if (atRoot) { + this.tables[++this.curIdx] = symbolNode.children + } else { + this.tables[++this.curIdx] = this.getNewTable() + } this.curTable = this.tables[this.curIdx] this.curClassIdx = this.curIdx } @@ -189,6 +200,7 @@ export class SymbolTable { insertFieldInfo(info: FieldInfo) { const key = generateSymbol(info.name, SymbolType.FIELD) + this.curTable = this.tables[this.curIdx] if (this.curTable.has(key)) { throw new SymbolRedeclarationError(info.name) } @@ -203,6 +215,7 @@ export class SymbolTable { insertMethodInfo(info: MethodInfo) { const key = generateSymbol(info.name, SymbolType.METHOD) + this.curTable = this.tables[this.curIdx] if (!this.curTable.has(key)) { const symbolNode: SymbolNode = { info: [info], @@ -235,6 +248,7 @@ export class SymbolTable { info: info, children: this.getNewTable() } + this.curTable = this.tables[this.curIdx] this.curTable.set(key, symbolNode) } @@ -266,6 +280,36 @@ export class SymbolTable { throw new SymbolNotFoundError(name) } + private getClassTable(name: string): Table { + let key = generateSymbol(name, SymbolType.CLASS) + for (let i = this.curIdx; i > 0; i--) { + const table = this.tables[i] + if (table.has(key)) { + return table.get(key)!.children + } + } + + const root = this.tables[0] + if (this.importedClassMap.has(name)) { + const fullName = this.importedClassMap.get(name)! + key = generateSymbol(fullName, SymbolType.CLASS) + if (root.has(key)) { + return root.get(key)!.children + } + } + + let p: string + for (p of this.importedPackages) { + const fullName = p + name + key = generateSymbol(fullName, SymbolType.CLASS) + if (root.has(key)) { + return root.get(key)!.children + } + } + + throw new SymbolNotFoundError(name) + } + private querySymbol(name: string, symbolType: SymbolType): Array { let curTable = this.getNewTable() const symbolInfos: Array = [] @@ -287,8 +331,7 @@ export class SymbolTable { if (token === 'this') { curTable = this.tables[this.curClassIdx] } else { - const key = generateSymbol(this.queryClass(token).name, SymbolType.CLASS) - curTable = this.tables[0].get(key)!.children + curTable = this.getClassTable(token) } } else if (i < len - 1) { const key = generateSymbol(token, SymbolType.FIELD) @@ -299,8 +342,7 @@ export class SymbolTable { symbolInfos.push(node.info) const typeName = (node.info as FieldInfo).typeName - const type = generateSymbol(this.queryClass(typeName).name, SymbolType.CLASS) - curTable = this.tables[0].get(type)!.children + curTable = this.getClassTable(typeName) } else { const key = generateSymbol(token, symbolType) const node = curTable.get(key) @@ -330,10 +372,20 @@ export class SymbolTable { } } + const results: Array = [] const key2 = generateSymbol(name, SymbolType.METHOD) - const table = this.tables[this.curClassIdx] - if (table.has(key2)) { - return [table.get(key2)!.info] + for (let i = this.curIdx; i > 0; i--) { + const table = this.tables[i] + if (table.has(key2)) { + const methodInfos = table.get(key2)!.info as MethodInfos + for (const methodInfo of methodInfos) { + results.push(methodInfo) + } + } + } + + if (results.length > 0) { + return results } throw new InvalidMethodCallError(name) } @@ -346,7 +398,7 @@ export class SymbolTable { const key1 = generateSymbol(name, SymbolType.VARIABLE) const key2 = generateSymbol(name, SymbolType.FIELD) - for (let i = this.curIdx; i >= 0; i--) { + for (let i = this.curIdx; i > this.curClassIdx; i--) { const table = this.tables[i] if (table.has(key1)) { return (table.get(key1) as SymbolNode).info as VariableInfo From 374c8c14be630c9706992747b0accfb7a9ba9750 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Thu, 3 Apr 2025 11:36:00 +0800 Subject: [PATCH 21/29] Add test cases for method overriding. --- .../__tests__/tests/methodOverriding.test.ts | 217 ++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 src/compiler/__tests__/tests/methodOverriding.test.ts diff --git a/src/compiler/__tests__/tests/methodOverriding.test.ts b/src/compiler/__tests__/tests/methodOverriding.test.ts new file mode 100644 index 00000000..e3542d1e --- /dev/null +++ b/src/compiler/__tests__/tests/methodOverriding.test.ts @@ -0,0 +1,217 @@ +import { runTest, testCase } from '../__utils__/test-utils' + +const testCases: testCase[] = [ + { + comment: 'Basic method overriding', + program: ` + class Parent1 { + public void show() { + System.out.println("Parent show"); + } + } + class Child1 extends Parent1 { + public void show() { + System.out.println("Child show"); + } + } + public class Main1 { + public static void main(String[] args) { + Parent1 p = new Parent1(); + p.show(); // Parent show + Child1 c = new Child1(); + c.show(); // Child show + Parent1 ref = new Child1(); + ref.show(); // Child show (dynamic dispatch) + } + } + `, + expectedLines: ['Parent show', 'Child show', 'Child show'] + }, + { + comment: 'Overriding with different access modifiers', + program: ` + class Parent { + protected void display() { + System.out.println("Parent display"); + } + } + class Child extends Parent { + public void display() { // Increased visibility + System.out.println("Child display"); + } + } + public class Main { + public static void main(String[] args) { + Parent ref = new Child(); + ref.display(); // Child display + } + } + `, + expectedLines: ['Child display'] + }, + { + comment: 'Method overriding with return type covariance', + program: ` + class Parent { + public Number getValue() { + return 10; + } + } + class Child extends Parent { + public Integer getValue() { + return 20; + } + } + public class Main { + public static void main(String[] args) { + Parent ref = new Child(); + System.out.println(ref.getValue()); // 20 + } + } + `, + expectedLines: ['20'] + }, + { + comment: 'Overriding with multiple levels of inheritance', + program: ` + class GrandParent { + public void greet() { + System.out.println("Hello from GrandParent"); + } + } + class Parent extends GrandParent { + public void greet() { + System.out.println("Hello from Parent"); + } + } + class Child extends Parent { + public void greet() { + System.out.println("Hello from Child"); + } + } + public class Main { + public static void main(String[] args) { + GrandParent ref1 = new GrandParent(); + ref1.greet(); // GrandParent + GrandParent ref2 = new Parent(); + ref2.greet(); // Parent + GrandParent ref3 = new Child(); + ref3.greet(); // Child + } + } + `, + expectedLines: ['Hello from GrandParent', 'Hello from Parent', 'Hello from Child'] + }, + { + comment: 'Overriding and method hiding with static methods', + program: ` + class Parent { + public static void staticMethod() { + System.out.println("Parent static method"); + } + public void instanceMethod() { + System.out.println("Parent instance method"); + } + } + class Child extends Parent { + public static void staticMethod() { + System.out.println("Child static method"); + } + public void instanceMethod() { + System.out.println("Child instance method"); + } + } + public class Main { + public static void main(String[] args) { + Parent.staticMethod(); // Parent static method + Child.staticMethod(); // Child static method + Parent ref = new Child(); + ref.instanceMethod(); // Child instance method + } + } + `, + expectedLines: ['Parent static method', 'Child static method', 'Child instance method'] + }, + { + comment: 'Overriding final methods (should cause compilation error)', + program: ` + class Parent { + public final void show() { + System.out.println("Final method in Parent"); + } + } + class Child extends Parent { + // public void show() {} // Uncommenting should cause compilation error + } + public class Main { + public static void main(String[] args) { + Parent p = new Parent(); + p.show(); // Final method in Parent + } + } + `, + expectedLines: ['Final method in Parent'] + }, + { + comment: 'Overriding in a deep class hierarchy', + program: ` + class A { + public void test() { + System.out.println("A test"); + } + } + class B extends A { + public void test() { + System.out.println("B test"); + } + } + class C extends B { + public void test() { + System.out.println("C test"); + } + } + class D extends C { + public void test() { + System.out.println("D test"); + } + } + public class Main { + public static void main(String[] args) { + A ref = new D(); + ref.test(); // D test + } + } + `, + expectedLines: ['D test'] + }, + { + comment: 'Overriding private methods (should not override, treated as new method)', + program: ` + class Parent { + private void secret() { + System.out.println("Parent secret"); + } + } + class Child extends Parent { + public void secret() { + System.out.println("Child secret"); + } + } + public class Main { + public static void main(String[] args) { + Child c = new Child(); + c.secret(); // Child secret + } + } + `, + expectedLines: ['Child secret'] + } +] + +export const methodOverridingTest = () => + describe('method overriding', () => { + for (let testCase of testCases) { + const { comment, program, expectedLines } = testCase + it(comment, () => runTest(program, expectedLines)) + } + }) From 35bcce8903c19f5f9bd7cfb27c098e6542e7710e Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Thu, 3 Apr 2025 14:56:39 +0800 Subject: [PATCH 22/29] Fix bugs in symbol-table.ts and compiler.ts logic. Add method invocation logic to code-generator.ts --- .../__tests__/tests/methodOverriding.test.ts | 12 ++-- src/compiler/code-generator.ts | 55 +++++++++++++------ src/compiler/compiler.ts | 12 ++-- src/compiler/error.ts | 6 -- src/compiler/symbol-table.ts | 17 ++---- 5 files changed, 56 insertions(+), 46 deletions(-) diff --git a/src/compiler/__tests__/tests/methodOverriding.test.ts b/src/compiler/__tests__/tests/methodOverriding.test.ts index e3542d1e..de89f213 100644 --- a/src/compiler/__tests__/tests/methodOverriding.test.ts +++ b/src/compiler/__tests__/tests/methodOverriding.test.ts @@ -4,23 +4,23 @@ const testCases: testCase[] = [ { comment: 'Basic method overriding', program: ` - class Parent1 { + class Parent { public void show() { System.out.println("Parent show"); } } - class Child1 extends Parent1 { + class Child extends Parent { public void show() { System.out.println("Child show"); } } - public class Main1 { + public class Main { public static void main(String[] args) { - Parent1 p = new Parent1(); + Parent p = new Parent(); p.show(); // Parent show - Child1 c = new Child1(); + Child c = new Child(); c.show(); // Child show - Parent1 ref = new Child1(); + Parent ref = new Child(); ref.show(); // Child show (dynamic dispatch) } } diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 24b9d76c..fa50fb31 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -35,7 +35,6 @@ import { ConstantPoolManager } from './constant-pool-manager' import { AmbiguousMethodCallError, ConstructNotSupportedError, - MethodNotFoundError, NoMethodMatchingSignatureError } from './error' import { FieldInfo, MethodInfos, SymbolInfo, SymbolTable, VariableInfo } from './symbol-table' @@ -206,10 +205,22 @@ type CompileResult = { } const EMPTY_TYPE: string = '' -function areClassTypesCompatible(fromType: string, toType: string): boolean { +function areClassTypesCompatible(fromType: string, toType: string, cg: CodeGenerator): boolean { const cleanFrom = fromType.replace(/^L|;$/g, '') const cleanTo = toType.replace(/^L|;$/g, '') - return cleanFrom === cleanTo + if (cleanFrom === cleanTo) return true; + + try { + let current = cg.symbolTable.queryClass(cleanFrom); + while (current.parentClassName) { + const parentClean = current.parentClassName; + if (parentClean === cleanTo) return true; + current = cg.symbolTable.queryClass(parentClean); + } + } catch (e) { + return false; + } + return false; } function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator): number { @@ -218,7 +229,7 @@ function handleImplicitTypeConversion(fromType: string, toType: string, cg: Code } if (fromType.startsWith('L') || toType.startsWith('L')) { - if (areClassTypesCompatible(fromType, toType) || fromType === '') { + if (areClassTypesCompatible(fromType, toType, cg) || fromType === '') { return 0 } throw new Error(`Unsupported class type conversion: ${fromType} -> ${toType}`) @@ -350,11 +361,11 @@ function getExpressionType(node: Node, cg: CodeGenerator): string { return resultType } -function isSubtype(fromType: string, toType: string): boolean { +function isSubtype(fromType: string, toType: string, cg: CodeGenerator): boolean { return ( fromType === toType || typeConversionsImplicit[`${fromType}->${toType}`] !== undefined || - areClassTypesCompatible(fromType, toType) + areClassTypesCompatible(fromType, toType, cg) ) } @@ -738,13 +749,12 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }) const argDescriptor = '(' + argTypes.join('') + ')' - const symbolInfos = cg.symbolTable.queryMethod('') - const methodInfos = symbolInfos[symbolInfos.length - 1] as MethodInfos + const methodInfos = cg.symbolTable.queryMethod('') as MethodInfos for (let i = 0; i < methodInfos.length; i++) { const methodInfo = methodInfos[i] - if (methodInfo.typeDescriptor.includes(argDescriptor)) { + if (methodInfo.typeDescriptor.includes(argDescriptor) && methodInfo.className == id) { const method = cg.constantPoolManager.indexMethodrefInfo( - methodInfo.parentClassName, + methodInfo.className, methodInfo.name, methodInfo.typeDescriptor ) @@ -778,22 +788,31 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const n = node as MethodInvocation let maxStack = 1 let resultType = EMPTY_TYPE - const candidateMethods: MethodInfos = [] + let candidateMethods: MethodInfos = [] + let unqualifiedCall = false - // TODO: Write logic to get candidateMethods // --- Handle super. calls --- if (n.identifier.startsWith('super.')) { + candidateMethods = cg.symbolTable.queryMethod(n.identifier.slice(6)).pop() as MethodInfos + candidateMethods.filter(method => + method.className == cg.symbolTable.queryClass(cg.currentClass).parentClassName) + cg.code.push(OPCODE.ALOAD, 0); } // --- Handle qualified calls (e.g. System.out.println or p.show) --- else if (n.identifier.includes('.')) { + // TODO: Load target object before method call + candidateMethods = cg.symbolTable.queryMethod(n.identifier).pop() as MethodInfos } - // --- Handle unqualified calls (including this.method()) --- + // --- Handle unqualified calls --- else { + candidateMethods = cg.symbolTable.queryMethod(n.identifier) as MethodInfos + unqualifiedCall = true; } // Filter candidate methods by matching the argument list. const argDescs = n.argumentList.map(arg => getExpressionType(arg, cg)) const methodMatches: MethodInfos = [] + for (let i = 0; i < candidateMethods.length; i++) { const m = candidateMethods[i] const fullDesc = m.typeDescriptor // e.g., "(Ljava/lang/String;C)V" @@ -805,7 +824,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const argType = argDescs[i] // Allow B/S to match int. if ((argType === 'B' || argType === 'S') && params[i] === 'I') continue - if (!isSubtype(argType, params[i])) { + if (!isSubtype(argType, params[i], cg)) { match = false break } @@ -829,17 +848,21 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi .slice(1, methodMatches[i].typeDescriptor.indexOf(')')) .match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] if ( - candParams.map((p, idx) => isSubtype(p, currParams[idx])).reduce((a, b) => a && b, true) + candParams.map((p, idx) => isSubtype(p, currParams[idx], cg)).reduce((a, b) => a && b, true) ) { selectedMethod = methodMatches[i] } else if ( - !currParams.map((p, idx) => isSubtype(p, candParams[idx])).reduce((a, b) => a && b, true) + !currParams.map((p, idx) => isSubtype(p, candParams[idx], cg)).reduce((a, b) => a && b, true) ) { throw new AmbiguousMethodCallError(n.identifier + argDescs.join(',')) } } } + if (unqualifiedCall && !(selectedMethod.accessFlags & FIELD_FLAGS.ACC_STATIC)) { + cg.code.push(OPCODE.ALOAD, 0) + } + // Compile each argument. const fullDescriptor = selectedMethod.typeDescriptor const paramPart = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) diff --git a/src/compiler/compiler.ts b/src/compiler/compiler.ts index 84ca6ded..6cd19f5d 100644 --- a/src/compiler/compiler.ts +++ b/src/compiler/compiler.ts @@ -39,10 +39,10 @@ export class Compiler { private setup() { this.symbolTable = new SymbolTable() - this.constantPoolManager = new ConstantPoolManager() } private resetClassFileState() { + this.constantPoolManager = new ConstantPoolManager() this.interfaces = [] this.fields = [] this.methods = [] @@ -54,14 +54,13 @@ export class Compiler { this.symbolTable.handleImports(ast.importDeclarations) const classFiles: Array = [] - ast.topLevelClassOrInterfaceDeclarations.forEach((decl, index) => { + ast.topLevelClassOrInterfaceDeclarations.forEach(decl => { const className = decl.typeIdentifier const parentClassName = decl.sclass ? decl.sclass : 'java/lang/Object' const accessFlags = generateClassAccessFlags(decl.classModifier) this.symbolTable.insertClassInfo( - { name: className, accessFlags: accessFlags, parentClassName: parentClassName }, - true - ) + { name: className, accessFlags: accessFlags, parentClassName: parentClassName }) + this.symbolTable.returnToRoot() }) ast.topLevelClassOrInterfaceDeclarations.forEach(decl => { @@ -77,7 +76,8 @@ export class Compiler { this.className = classNode.typeIdentifier this.parentClassName = classNode.sclass ? classNode.sclass : 'java/lang/Object' const accessFlags = generateClassAccessFlags(classNode.classModifier) - this.symbolTable.insertClassInfo({ name: this.className, accessFlags: accessFlags }, false) + this.symbolTable.extend() + this.symbolTable.insertClassInfo({ name: this.className, accessFlags: accessFlags }) const superClassIndex = this.constantPoolManager.indexClassInfo(this.parentClassName) const thisClassIndex = this.constantPoolManager.indexClassInfo(this.className) diff --git a/src/compiler/error.ts b/src/compiler/error.ts index 94605f15..3e3e906e 100644 --- a/src/compiler/error.ts +++ b/src/compiler/error.ts @@ -34,12 +34,6 @@ export class ConstructNotSupportedError extends CompileError { } } -export class MethodNotFoundError extends CompileError { - constructor(methodName: string, className: string) { - super(`Method ${methodName} not found in inheritance chain of ${className}`) - } -} - export class NoMethodMatchingSignatureError extends CompileError { constructor(signature: string) { super(`No method matching signature ${signature}) found.`) diff --git a/src/compiler/symbol-table.ts b/src/compiler/symbol-table.ts index 89dab0a5..907d9119 100644 --- a/src/compiler/symbol-table.ts +++ b/src/compiler/symbol-table.ts @@ -107,9 +107,7 @@ export class SymbolTable { { name: c.className, accessFlags: generateClassAccessFlags(c.accessFlags) - }, - true - ) + }) c.fields.forEach(f => this.insertFieldInfo({ name: f.fieldName, @@ -137,7 +135,7 @@ export class SymbolTable { return new Map() } - private returnToRoot() { + public returnToRoot() { this.tables = [this.tables[0]] this.curTable = this.tables[0] this.curIdx = 0 @@ -174,7 +172,7 @@ export class SymbolTable { this.curTable = this.tables[this.curIdx] } - insertClassInfo(info: ClassInfo, atRoot: boolean) { + insertClassInfo(info: ClassInfo) { const key = generateSymbol(info.name, SymbolType.CLASS) if (this.curTable.has(key)) { @@ -187,12 +185,7 @@ export class SymbolTable { } this.curTable.set(key, symbolNode) - // this logic will need to be modified for inner classes - if (atRoot) { - this.tables[++this.curIdx] = symbolNode.children - } else { - this.tables[++this.curIdx] = this.getNewTable() - } + this.tables[++this.curIdx] = symbolNode.children this.curTable = this.tables[this.curIdx] this.curClassIdx = this.curIdx } @@ -282,7 +275,7 @@ export class SymbolTable { private getClassTable(name: string): Table { let key = generateSymbol(name, SymbolType.CLASS) - for (let i = this.curIdx; i > 0; i--) { + for (let i = this.curIdx; i >= 0; i--) { const table = this.tables[i] if (table.has(key)) { return table.get(key)!.children From d4d4bf15954cb97b45145f584de3fcf66c1c071a Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Thu, 3 Apr 2025 17:21:22 +0800 Subject: [PATCH 23/29] Fix method invocation logic in target loading for qualified calls --- src/compiler/code-generator.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index fa50fb31..e584db8b 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -800,7 +800,10 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } // --- Handle qualified calls (e.g. System.out.println or p.show) --- else if (n.identifier.includes('.')) { - // TODO: Load target object before method call + const lastDot = n.identifier.lastIndexOf('.'); + const receiverStr = n.identifier.slice(0, lastDot); + const recvRes = compile({ kind: 'ExpressionName', name: receiverStr }, cg); + maxStack = Math.max(maxStack, recvRes.stackSize); candidateMethods = cg.symbolTable.queryMethod(n.identifier).pop() as MethodInfos } // --- Handle unqualified calls --- @@ -1220,7 +1223,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } const fieldInfo = fieldInfos[i] as FieldInfo const field = cg.constantPoolManager.indexFieldrefInfo( - fieldInfo.parentClassName, + fieldInfo.typeName, fieldInfo.name, fieldInfo.typeDescriptor ) From 1ca11986a5c3a51981d9cdd3674426ab960deab2 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Sun, 6 Apr 2025 12:06:34 +0800 Subject: [PATCH 24/29] Fix bugs in method invocation logic for static method calls, and in static field lookup logic. --- .../__tests__/tests/methodOverriding.test.ts | 32 ++++--------------- src/compiler/code-generator.ts | 26 +++++++++++---- src/compiler/compiler.ts | 2 +- src/compiler/symbol-table.ts | 8 ++--- 4 files changed, 31 insertions(+), 37 deletions(-) diff --git a/src/compiler/__tests__/tests/methodOverriding.test.ts b/src/compiler/__tests__/tests/methodOverriding.test.ts index de89f213..40ac00b6 100644 --- a/src/compiler/__tests__/tests/methodOverriding.test.ts +++ b/src/compiler/__tests__/tests/methodOverriding.test.ts @@ -49,28 +49,6 @@ const testCases: testCase[] = [ `, expectedLines: ['Child display'] }, - { - comment: 'Method overriding with return type covariance', - program: ` - class Parent { - public Number getValue() { - return 10; - } - } - class Child extends Parent { - public Integer getValue() { - return 20; - } - } - public class Main { - public static void main(String[] args) { - Parent ref = new Child(); - System.out.println(ref.getValue()); // 20 - } - } - `, - expectedLines: ['20'] - }, { comment: 'Overriding with multiple levels of inheritance', program: ` @@ -141,7 +119,7 @@ const testCases: testCase[] = [ } } class Child extends Parent { - // public void show() {} // Uncommenting should cause compilation error + public void show() {} // Uncommenting should cause compilation error } public class Main { public static void main(String[] args) { @@ -177,12 +155,14 @@ const testCases: testCase[] = [ } public class Main { public static void main(String[] args) { - A ref = new D(); - ref.test(); // D test + A ref1 = new D(); + B ref2 = new C(); + ref1.test(); // D test + ref2.test(); // C test } } `, - expectedLines: ['D test'] + expectedLines: ['D test', 'C test'] }, { comment: 'Overriding private methods (should not override, treated as new method)', diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index e584db8b..3bea9dbc 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -793,7 +793,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi // --- Handle super. calls --- if (n.identifier.startsWith('super.')) { - candidateMethods = cg.symbolTable.queryMethod(n.identifier.slice(6)).pop() as MethodInfos + candidateMethods = cg.symbolTable.queryMethod(n.identifier.slice(6)) as MethodInfos candidateMethods.filter(method => method.className == cg.symbolTable.queryClass(cg.currentClass).parentClassName) cg.code.push(OPCODE.ALOAD, 0); @@ -802,9 +802,18 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi else if (n.identifier.includes('.')) { const lastDot = n.identifier.lastIndexOf('.'); const receiverStr = n.identifier.slice(0, lastDot); - const recvRes = compile({ kind: 'ExpressionName', name: receiverStr }, cg); - maxStack = Math.max(maxStack, recvRes.stackSize); - candidateMethods = cg.symbolTable.queryMethod(n.identifier).pop() as MethodInfos + + if (receiverStr === 'this') { + candidateMethods = cg.symbolTable.queryMethod(n.identifier.slice(5)) as MethodInfos + console.debug(candidateMethods) + candidateMethods.filter(method => + method.className == cg.currentClass) + cg.code.push(OPCODE.ALOAD, 0); + } else { + const recvRes = compile({ kind: 'ExpressionName', name: receiverStr }, cg); + maxStack = Math.max(maxStack, recvRes.stackSize); + candidateMethods = cg.symbolTable.queryMethod(n.identifier).pop() as MethodInfos + } } // --- Handle unqualified calls --- else { @@ -1210,7 +1219,12 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } } - const info = cg.symbolTable.queryVariable(name) + let info: VariableInfo | SymbolInfo[] + try { + info = cg.symbolTable.queryVariable(name) + } catch (e) { + return { stackSize: 1, resultType: 'Ljava/lang/Class;' }; + } if (Array.isArray(info)) { const fieldInfos = info for (let i = 0; i < fieldInfos.length; i++) { @@ -1223,7 +1237,7 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } const fieldInfo = fieldInfos[i] as FieldInfo const field = cg.constantPoolManager.indexFieldrefInfo( - fieldInfo.typeName, + fieldInfo.parentClassName, fieldInfo.name, fieldInfo.typeDescriptor ) diff --git a/src/compiler/compiler.ts b/src/compiler/compiler.ts index 6cd19f5d..fd0bf3d0 100644 --- a/src/compiler/compiler.ts +++ b/src/compiler/compiler.ts @@ -173,7 +173,7 @@ export class Compiler { this.symbolTable.insertFieldInfo({ name: v.variableDeclaratorId, accessFlags: accessFlags, - parentClassName: this.parentClassName, + parentClassName: this.className, typeName: fullType, typeDescriptor: typeDescriptor }) diff --git a/src/compiler/symbol-table.ts b/src/compiler/symbol-table.ts index 907d9119..33db69f3 100644 --- a/src/compiler/symbol-table.ts +++ b/src/compiler/symbol-table.ts @@ -231,7 +231,7 @@ export class SymbolTable { insertVariableInfo(info: VariableInfo) { const key = generateSymbol(info.name, SymbolType.VARIABLE) - for (let i = this.curIdx; i > this.curClassIdx; i--) { + for (let i = this.curIdx; i >= this.curClassIdx; i--) { if (this.tables[i].has(key)) { throw new SymbolRedeclarationError(info.name) } @@ -312,7 +312,7 @@ export class SymbolTable { tokens.forEach((token, i) => { if (i === 0) { const key1 = generateSymbol(token, SymbolType.VARIABLE) - for (let i = this.curIdx; i > this.curClassIdx; i--) { + for (let i = this.curIdx; i >= this.curClassIdx; i--) { if (this.tables[i].has(key1)) { const node = this.tables[i].get(key1)! token = (node.info as VariableInfo).typeName @@ -359,7 +359,7 @@ export class SymbolTable { } const key1 = generateSymbol(name, SymbolType.VARIABLE) - for (let i = this.curIdx; i > this.curClassIdx; i--) { + for (let i = this.curIdx; i >= this.curClassIdx; i--) { if (this.tables[i].has(key1)) { throw new InvalidMethodCallError(name) } @@ -391,7 +391,7 @@ export class SymbolTable { const key1 = generateSymbol(name, SymbolType.VARIABLE) const key2 = generateSymbol(name, SymbolType.FIELD) - for (let i = this.curIdx; i > this.curClassIdx; i--) { + for (let i = this.curIdx; i >= this.curClassIdx; i--) { const table = this.tables[i] if (table.has(key1)) { return (table.get(key1) as SymbolNode).info as VariableInfo From 0bc9b53a0d69482b6f25f881160c8f3295372946 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Sun, 6 Apr 2025 18:00:48 +0800 Subject: [PATCH 25/29] Modify the compiler to return multiple class files with their file names. --- package.json | 7 +++++-- src/ClassFile/types/index.ts | 5 +++++ src/compiler/compiler.ts | 6 +++--- src/compiler/index.ts | 6 +++--- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/package.json b/package.json index 87093d66..cf2f2faf 100644 --- a/package.json +++ b/package.json @@ -3,7 +3,9 @@ "version": "1.0.13", "main": "dist/index.js", "types": "dist/index.d.ts", - "files": ["dist"], + "files": [ + "dist" + ], "repository": { "type": "git", "url": "git+https://github.com/source-academy/java-slang.git" @@ -40,5 +42,6 @@ "java-parser": "^2.0.5", "lodash": "^4.17.21", "peggy": "^4.0.2" - } + }, + "packageManager": "yarn@1.22.22+sha1.ac34549e6aa8e7ead463a7407e1c7390f61a6610" } diff --git a/src/ClassFile/types/index.ts b/src/ClassFile/types/index.ts index 4f963d44..71ccd814 100644 --- a/src/ClassFile/types/index.ts +++ b/src/ClassFile/types/index.ts @@ -3,6 +3,11 @@ import { ConstantInfo } from './constants' import { FieldInfo } from './fields' import { MethodInfo } from './methods' +export interface Class { + classFile: ClassFile + className: string +} + export interface ClassFile { magic: number minorVersion: number diff --git a/src/compiler/compiler.ts b/src/compiler/compiler.ts index fd0bf3d0..3e116800 100644 --- a/src/compiler/compiler.ts +++ b/src/compiler/compiler.ts @@ -1,4 +1,4 @@ -import { ClassFile } from '../ClassFile/types' +import { Class, ClassFile } from '../ClassFile/types' import { AST } from '../ast/types/packages-and-modules' import { ClassBodyDeclaration, @@ -52,7 +52,7 @@ export class Compiler { compile(ast: AST) { this.setup() this.symbolTable.handleImports(ast.importDeclarations) - const classFiles: Array = [] + const classFiles: Array = [] ast.topLevelClassOrInterfaceDeclarations.forEach(decl => { const className = decl.typeIdentifier @@ -66,7 +66,7 @@ export class Compiler { ast.topLevelClassOrInterfaceDeclarations.forEach(decl => { this.resetClassFileState() const classFile = this.compileClass(decl) - classFiles.push(classFile) + classFiles.push({classFile: classFile, className: this.className}) }) return classFiles diff --git a/src/compiler/index.ts b/src/compiler/index.ts index 3ad4eb1f..13202756 100644 --- a/src/compiler/index.ts +++ b/src/compiler/index.ts @@ -1,16 +1,16 @@ import * as peggy from 'peggy' import { AST } from '../ast/types/packages-and-modules' -import { ClassFile } from '../ClassFile/types' +import { Class } from '../ClassFile/types' import { Compiler } from './compiler' import { javaPegGrammar } from './grammar' import { peggyFunctions } from './peggy-functions' -export const compile = (ast: AST): Array => { +export const compile = (ast: AST): Array => { const compiler = new Compiler() return compiler.compile(ast) } -export const compileFromSource = (javaProgram: string): Array => { +export const compileFromSource = (javaProgram: string): Array => { const parser = peggy.generate(peggyFunctions + javaPegGrammar, { allowedStartRules: ['CompilationUnit'], cache: true From d2517b2a236ba4c9de10120ddea067a2b08e0efb Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Sun, 6 Apr 2025 18:24:01 +0800 Subject: [PATCH 26/29] Fix bug to not allow final methods to be overriden --- .../__tests__/__utils__/test-utils.ts | 6 +- .../__tests__/tests/methodOverriding.test.ts | 135 +++++++++++++++++- src/compiler/error.ts | 6 + src/compiler/symbol-table.ts | 27 +++- 4 files changed, 164 insertions(+), 10 deletions(-) diff --git a/src/compiler/__tests__/__utils__/test-utils.ts b/src/compiler/__tests__/__utils__/test-utils.ts index 39b6f156..382c3907 100644 --- a/src/compiler/__tests__/__utils__/test-utils.ts +++ b/src/compiler/__tests__/__utils__/test-utils.ts @@ -30,9 +30,9 @@ export function runTest(program: string, expectedLines: string[]) { console.log(inspect(ast, false, null, true)) } - const classFiles = compile(ast as AST) - for (let classFile of classFiles) { - binaryWriter.writeBinary(classFile, pathToTestDir) + const classes = compile(ast as AST) + for (let c of classes) { + binaryWriter.writeBinary(c.classFile, pathToTestDir) } const prevDir = process.cwd() diff --git a/src/compiler/__tests__/tests/methodOverriding.test.ts b/src/compiler/__tests__/tests/methodOverriding.test.ts index 40ac00b6..15a89097 100644 --- a/src/compiler/__tests__/tests/methodOverriding.test.ts +++ b/src/compiler/__tests__/tests/methodOverriding.test.ts @@ -119,7 +119,7 @@ const testCases: testCase[] = [ } } class Child extends Parent { - public void show() {} // Uncommenting should cause compilation error + // public void show() {} // Uncommenting should cause compilation error } public class Main { public static void main(String[] args) { @@ -185,6 +185,139 @@ const testCases: testCase[] = [ } `, expectedLines: ['Child secret'] + }, + { + comment: 'Using this to call an instance method', + program: ` + class Self { + public void print() { + System.out.println("Self print"); + } + public void callSelf() { + this.print(); + } + } + public class Main { + public static void main(String[] args) { + Self s = new Self(); + s.callSelf(); // Self print + } + } + `, + expectedLines: ['Self print'] + }, + { + comment: 'Using super to invoke parent method', + program: ` + class Base { + public void greet() { + System.out.println("Hello from Base"); + } + } + class Derived extends Base { + public void greet() { + super.greet(); + System.out.println("Hello from Derived"); + } + } + public class Main { + public static void main(String[] args) { + Derived d = new Derived(); + d.greet(); + // Expected: + // Hello from Base + // Hello from Derived + } + } + `, + expectedLines: ['Hello from Base', 'Hello from Derived'] + }, + { + comment: 'Polymorphic call with dynamic dispatch', + program: ` + class Animal { + public void speak() { + System.out.println("Animal sound"); + } + } + class Dog extends Animal { + public void speak() { + System.out.println("Bark"); + } + public void callSuper() { + super.speak(); + } + } + public class Main { + public static void main(String[] args) { + Dog d = new Dog(); + d.speak(); // Bark + d.callSuper(); // Animal sound + } + } + `, + expectedLines: ['Bark', 'Animal sound'] + }, + { + comment: 'Method overloading resolution', + program: ` + class Overload { + public void test(int a) { + System.out.println("int"); + } + public void test(double a) { + System.out.println("double"); + } + } + public class Main { + public static void main(String[] args) { + Overload o = new Overload(); + o.test(5); // int + o.test(5.0); // double + } + } + `, + expectedLines: ['int', 'double'] + }, + { + comment: 'Overriding on a superclass reference', + program: ` + class X { + public void foo() { + System.out.println("X foo"); + } + } + class Y extends X { + public void foo() { + System.out.println("Y foo"); + } + } + public class Main { + public static void main(String[] args) { + X x = new Y(); + x.foo(); // Y foo + } + } + `, + expectedLines: ['Y foo'] + }, + { + comment: 'Implicit conversion (byte to int)', + program: ` + class Implicit { + public void process(int a) { + System.out.println("Processed int"); + } + } + public class Main { + public static void main(String[] args) { + Implicit imp = new Implicit(); + byte b = (byte) 10; + imp.process(b); // Processed int + } + } + `, + expectedLines: ['Processed int'] } ] diff --git a/src/compiler/error.ts b/src/compiler/error.ts index 3e3e906e..1044d4fd 100644 --- a/src/compiler/error.ts +++ b/src/compiler/error.ts @@ -45,3 +45,9 @@ export class AmbiguousMethodCallError extends CompileError { super(`Ambiguous method call: ${signature}`) } } + +export class OverrideFinalMethodError extends CompileError { + constructor(name: string) { + super(`Cannot override final method ${name}`) + } +} \ No newline at end of file diff --git a/src/compiler/symbol-table.ts b/src/compiler/symbol-table.ts index 33db69f3..394ffd34 100644 --- a/src/compiler/symbol-table.ts +++ b/src/compiler/symbol-table.ts @@ -6,12 +6,13 @@ import { generateMethodAccessFlags } from './compiler-utils' import { - InvalidMethodCallError, + InvalidMethodCallError, OverrideFinalMethodError, SymbolCannotBeResolvedError, SymbolNotFoundError, SymbolRedeclarationError } from './error' import { libraries } from './import/libs' +import { METHOD_FLAGS } from '../ClassFile/types/methods' export const typeMap = new Map([ ['byte', 'B'], @@ -103,11 +104,10 @@ export class SymbolTable { if (this.importedPackages.findIndex(e => e == p.packageName + '/') == -1) this.importedPackages.push(p.packageName + '/') p.classes.forEach(c => { - this.insertClassInfo( - { - name: c.className, - accessFlags: generateClassAccessFlags(c.accessFlags) - }) + this.insertClassInfo({ + name: c.className, + accessFlags: generateClassAccessFlags(c.accessFlags) + }) c.fields.forEach(f => this.insertFieldInfo({ name: f.fieldName, @@ -208,6 +208,21 @@ export class SymbolTable { insertMethodInfo(info: MethodInfo) { const key = generateSymbol(info.name, SymbolType.METHOD) + for (let i = this.curClassIdx - 1; i > 0; i--) { + const parentTable = this.tables[i]; + if (parentTable.has(key)) { + const parentMethods = parentTable.get(key)!.info; + if (Array.isArray(parentMethods)) { + for (const m of parentMethods) { + if (m.typeDescriptor === info.typeDescriptor && (m.accessFlags & METHOD_FLAGS.ACC_FINAL) + && m.className == info.parentClassName) { + throw new OverrideFinalMethodError(info.name); + } + } + } + } + } + this.curTable = this.tables[this.curIdx] if (!this.curTable.has(key)) { const symbolNode: SymbolNode = { From d3de7310f75110ac0fec08ebbf2de5174ecb1cc1 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Mon, 7 Apr 2025 12:25:57 +0800 Subject: [PATCH 27/29] Modify the CSEC machine to be able to run conditional expressions Add a new instruction for the CSEC machine: cond --- src/ast/utils/astToString.ts | 8 +++-- src/ec-evaluator/instrCreator.ts | 17 ++++++++-- src/ec-evaluator/interpreter.ts | 45 ++++++++++++++++++++++---- src/ec-evaluator/types.ts | 9 +++++- src/ec-evaluator/utils.ts | 54 ++++++++++++++++++++++++++++++-- 5 files changed, 118 insertions(+), 15 deletions(-) diff --git a/src/ast/utils/astToString.ts b/src/ast/utils/astToString.ts index bf4f38d0..0deb3b6e 100644 --- a/src/ast/utils/astToString.ts +++ b/src/ast/utils/astToString.ts @@ -10,8 +10,8 @@ import { Literal, LocalVariableDeclarationStatement, MethodInvocation, - ReturnStatement, -} from '../types/blocks-and-statements'; + ReturnStatement, TernaryExpression +} from '../types/blocks-and-statements' import { ConstructorDeclaration, FieldDeclaration, @@ -114,6 +114,10 @@ export const astToString = (node: Node, indent: number = 0): string => { const bin = node as BinaryExpression; return `${astToString(bin.left)} ${bin.operator} ${astToString(bin.right)}`; + case "TernaryExpression": + const ter = node as TernaryExpression; + return `${astToString(ter.condition)} ? ${astToString(ter.consequent)} : ${astToString(ter.alternate)}`; + case "ExpressionName": const exp = node as ExpressionName; return exp.name; diff --git a/src/ec-evaluator/instrCreator.ts b/src/ec-evaluator/instrCreator.ts index d6aa6b37..e2a3287f 100644 --- a/src/ec-evaluator/instrCreator.ts +++ b/src/ec-evaluator/instrCreator.ts @@ -4,7 +4,7 @@ import { EnvNode } from "./components"; import { AssmtInstr, BinOpInstr, - Class, + Class, CondInstr, DerefInstr, EnvInstr, EvalVarInstr, @@ -18,8 +18,8 @@ import { ResOverloadInstr, ResOverrideInstr, ResTypeContInstr, - ResTypeInstr, -} from "./types"; + ResTypeInstr +} from './types' export const assmtInstr = ( srcNode: Node, @@ -154,3 +154,14 @@ export const resConOverloadInstr = ( srcNode, arity, }); + +export const condInstr = ( + trueExpr: Expression, + falseExpr: Expression, + srcNode: Node, +): CondInstr => ({ + instrType: InstrType.COND, + trueExpr, + falseExpr, + srcNode +}); diff --git a/src/ec-evaluator/interpreter.ts b/src/ec-evaluator/interpreter.ts index 0eeb08e4..3bbd7967 100644 --- a/src/ec-evaluator/interpreter.ts +++ b/src/ec-evaluator/interpreter.ts @@ -1,6 +1,6 @@ import { cloneDeep } from "lodash"; -import { +import { Assignment, BinaryExpression, Block, @@ -13,10 +13,10 @@ import { LocalVariableDeclarationStatement, LocalVariableType, MethodInvocation, - ReturnStatement, + ReturnStatement, TernaryExpression, VariableDeclarator, - Void, -} from "../ast/types/blocks-and-statements"; + Void +} from '../ast/types/blocks-and-statements' import { ConstructorDeclaration, FieldDeclaration, @@ -58,8 +58,8 @@ import { ResConOverloadInstr, ResOverrideInstr, ResTypeContInstr, - StructType, -} from "./types"; + StructType, CondInstr +} from './types' import { defaultValues, evaluateBinaryExpression, @@ -413,6 +413,17 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { control.push(command.left); }, + TernaryExpression: ( + command: TernaryExpression, + _environment: Environment, + control: Control, + _stash: Stash, + ) => { + control.push(instr.condInstr(command.consequent, command.alternate, command)); + control.push(command.condition); + }, + + [InstrType.POP]: ( _command: Instr, _environment: Environment, @@ -841,4 +852,26 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { // No post-processing required for constructor. }, + + [InstrType.COND]: ( + command: CondInstr, + _environment: Environment, + control: Control, + stash: Stash, + ) => { + // Pop the condition result (assumed to be a Literal). + const conditionValue = stash.pop() as Literal; + + const isTruthy = (value: Literal): boolean => { + return (value.literalType.kind == 'BooleanLiteral' && value.literalType.value == 'true') + || (value.literalType.kind != 'BooleanLiteral' && Boolean(value.literalType.value)) + }; + + // Determine truthiness (you may need to adjust this to your language's rules). + if (isTruthy(conditionValue)) { + control.push(command.trueExpr); + } else { + control.push(command.falseExpr); + } + } }; diff --git a/src/ec-evaluator/types.ts b/src/ec-evaluator/types.ts index 3b7eb17a..3a62814e 100644 --- a/src/ec-evaluator/types.ts +++ b/src/ec-evaluator/types.ts @@ -40,6 +40,7 @@ export enum InstrType { RES_OVERLOAD = 'ResOverload', RES_OVERRIDE = 'ResOverride', RES_CON_OVERLOAD = 'ResConOverload', + COND = 'Cond' } interface BaseInstr { @@ -98,6 +99,11 @@ export interface ResInstr extends BaseInstr { name: string; } +export interface CondInstr extends BaseInstr { + trueExpr: Expression; + falseExpr: Expression; +} + export interface DerefInstr extends BaseInstr {} export type Instr = @@ -115,7 +121,8 @@ export type Instr = | ResTypeInstr | ResTypeContInstr | ResOverloadInstr - | ResConOverloadInstr; + | ResConOverloadInstr + | CondInstr; /** * Components diff --git a/src/ec-evaluator/utils.ts b/src/ec-evaluator/utils.ts index ac73751e..e3807728 100644 --- a/src/ec-evaluator/utils.ts +++ b/src/ec-evaluator/utils.ts @@ -1,13 +1,13 @@ import { Node } from "../ast/types/ast"; import { - BlockStatement, + BlockStatement, BooleanLiteral, DecimalIntegerLiteral, Expression, ExpressionStatement, Literal, MethodInvocation, - ReturnStatement, -} from "../ast/types/blocks-and-statements"; + ReturnStatement +} from '../ast/types/blocks-and-statements' import { ConstructorDeclaration, FieldDeclaration, @@ -143,6 +143,54 @@ export const evaluateBinaryExpression = (operator: string, left: Literal, right: value: String(Number(left.literalType.value) / Number(right.literalType.value)), } as DecimalIntegerLiteral, }; + case ">": + return { + kind: "Literal", + literalType: { + kind: "BooleanLiteral", + value: left.literalType.value > right.literalType.value ? 'true' : 'false' + } as BooleanLiteral + }; + case "<": + return { + kind: "Literal", + literalType: { + kind: "BooleanLiteral", + value: left.literalType.value < right.literalType.value ? 'true' : 'false' + } as BooleanLiteral + } + case ">=": + return { + kind: "Literal", + literalType: { + kind: "BooleanLiteral", + value: left.literalType.value >= right.literalType.value ? 'true' : 'false' + } as BooleanLiteral + } + case "<=": + return { + kind: "Literal", + literalType: { + kind: "BooleanLiteral", + value: left.literalType.value <= right.literalType.value ? 'true' : 'false' + } as BooleanLiteral + } + case "==": + return { + kind: "Literal", + literalType: { + kind: "BooleanLiteral", + value: left.literalType.value == right.literalType.value ? 'true' : 'false' + } as BooleanLiteral + } + case "!=": + return { + kind: "Literal", + literalType: { + kind: "BooleanLiteral", + value: left.literalType.value != right.literalType.value ? 'true' : 'false' + } as BooleanLiteral + } default /* case "%" */: return { kind: "Literal", From 06b2683d98d4f7b64d94ebfa4fe461ca155aaec4 Mon Sep 17 00:00:00 2001 From: Aprup Kale Date: Mon, 7 Apr 2025 17:19:26 +0800 Subject: [PATCH 28/29] Modify the CSEC machine to be able to run switch statements Add a new instruction for the CSEC machine: switch --- src/ast/utils/astToString.ts | 27 ++++++++++++++- src/ec-evaluator/instrCreator.ts | 15 +++++++-- src/ec-evaluator/interpreter.ts | 57 ++++++++++++++++++++++++++++++-- src/ec-evaluator/types.ts | 13 ++++++-- 4 files changed, 104 insertions(+), 8 deletions(-) diff --git a/src/ast/utils/astToString.ts b/src/ast/utils/astToString.ts index 0deb3b6e..39e007e9 100644 --- a/src/ast/utils/astToString.ts +++ b/src/ast/utils/astToString.ts @@ -10,7 +10,7 @@ import { Literal, LocalVariableDeclarationStatement, MethodInvocation, - ReturnStatement, TernaryExpression + ReturnStatement, SwitchStatement, TernaryExpression } from '../types/blocks-and-statements' import { ConstructorDeclaration, @@ -118,6 +118,31 @@ export const astToString = (node: Node, indent: number = 0): string => { const ter = node as TernaryExpression; return `${astToString(ter.condition)} ? ${astToString(ter.consequent)} : ${astToString(ter.alternate)}`; + case "BreakStatement": + return 'break' + + case 'SwitchStatement': + const sw = node as SwitchStatement; + let result = indentLine(indent, `switch (${astToString(sw.expression)}) {`) + "\n"; + for (const swCase of sw.cases) { + // Print each label on its own line. + for (const label of swCase.labels) { + if (label.kind === "CaseLabel") { + result += indentLine(indent + INDENT_SPACES, `case ${astToString(label.expression)}:`) + "\n"; + } else if (label.kind === "DefaultLabel") { + result += indentLine(indent + INDENT_SPACES, `default:`) + "\n"; + } + } + // Print the statements for this case, if any. + if (swCase.statements && swCase.statements.length > 0) { + for (const stmt of swCase.statements) { + result += newline(astToString(stmt, indent + INDENT_SPACES * 2)); + } + } + } + result += indentLine(indent, "}"); + return result; + case "ExpressionName": const exp = node as ExpressionName; return exp.name; diff --git a/src/ec-evaluator/instrCreator.ts b/src/ec-evaluator/instrCreator.ts index e2a3287f..ef7d4ba9 100644 --- a/src/ec-evaluator/instrCreator.ts +++ b/src/ec-evaluator/instrCreator.ts @@ -1,5 +1,5 @@ import { Node } from "../ast/types/ast"; -import { Expression } from "../ast/types/blocks-and-statements"; +import { Expression, SwitchCase } from '../ast/types/blocks-and-statements' import { EnvNode } from "./components"; import { AssmtInstr, @@ -18,7 +18,7 @@ import { ResOverloadInstr, ResOverrideInstr, ResTypeContInstr, - ResTypeInstr + ResTypeInstr, SwitchInstr } from './types' export const assmtInstr = ( @@ -165,3 +165,14 @@ export const condInstr = ( falseExpr, srcNode }); + +export const switchInstr = ( + cases: Array, + expr: Expression, + srcNode: Node +): SwitchInstr => ({ + instrType: InstrType.SWITCH, + cases, + expr, + srcNode, +}); diff --git a/src/ec-evaluator/interpreter.ts b/src/ec-evaluator/interpreter.ts index 3bbd7967..72090508 100644 --- a/src/ec-evaluator/interpreter.ts +++ b/src/ec-evaluator/interpreter.ts @@ -13,7 +13,7 @@ import { LocalVariableDeclarationStatement, LocalVariableType, MethodInvocation, - ReturnStatement, TernaryExpression, + ReturnStatement, SwitchCase, SwitchStatement, TernaryExpression, VariableDeclarator, Void } from '../ast/types/blocks-and-statements' @@ -58,7 +58,7 @@ import { ResConOverloadInstr, ResOverrideInstr, ResTypeContInstr, - StructType, CondInstr + StructType, CondInstr, SwitchInstr } from './types' import { defaultValues, @@ -423,6 +423,16 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { control.push(command.condition); }, + SwitchStatement: ( + command: SwitchStatement, + _environment: Environment, + control: Control, + _stash: Stash, + ) => { + control.push(instr.switchInstr(command.cases, command.expression, command)); + control.push(command.expression); + }, + [InstrType.POP]: ( _command: Instr, @@ -873,5 +883,48 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { } else { control.push(command.falseExpr); } + }, + + [InstrType.SWITCH]: ( + command: SwitchInstr, + _environment: Environment, + control: Control, + stash: Stash, + ) => { + // Pop the evaluated discriminant from the stash. + const discValue = stash.pop() as Literal; + + let matchedCase: SwitchCase | null = null; + + // Iterate over each switch case. + for (const swCase of command.cases) { + // Check all labels for this case. + for (const label of swCase.labels) { + if (label.kind === "CaseLabel") { + // Assume the case label's expression is a literal. + const caseLiteral = label.expression as Literal; + if (discValue.literalType.value === caseLiteral.literalType.value) { + matchedCase = swCase; + break; + } + } else if (label.kind === "DefaultLabel") { + // Save default case (only one default should exist). + matchedCase = swCase; + } + } + if (matchedCase) break; + } + + // Determine which case to use. + if (matchedCase) { + if (matchedCase && matchedCase.statements && matchedCase.statements.length > 0) { + // Push the statements in reverse order to the control stack. + for (let i = matchedCase.statements.length - 1; i >= 0; i--) { + if (matchedCase.statements[i].kind == "BreakStatement") + continue + control.push(matchedCase.statements[i]); + } + } + } } }; diff --git a/src/ec-evaluator/types.ts b/src/ec-evaluator/types.ts index 3a62814e..357e83ab 100644 --- a/src/ec-evaluator/types.ts +++ b/src/ec-evaluator/types.ts @@ -1,5 +1,5 @@ import { Node } from "../ast/types/ast"; -import { Expression, Literal, Void } from "../ast/types/blocks-and-statements"; +import { Expression, Literal, SwitchCase, Void } from '../ast/types/blocks-and-statements' import { ConstructorDeclaration, FieldDeclaration, @@ -40,7 +40,8 @@ export enum InstrType { RES_OVERLOAD = 'ResOverload', RES_OVERRIDE = 'ResOverride', RES_CON_OVERLOAD = 'ResConOverload', - COND = 'Cond' + COND = 'Cond', + SWITCH = 'Switch', } interface BaseInstr { @@ -104,6 +105,11 @@ export interface CondInstr extends BaseInstr { falseExpr: Expression; } +export interface SwitchInstr extends BaseInstr { + cases: Array; + expr: Expression; +} + export interface DerefInstr extends BaseInstr {} export type Instr = @@ -122,7 +128,8 @@ export type Instr = | ResTypeContInstr | ResOverloadInstr | ResConOverloadInstr - | CondInstr; + | CondInstr + | SwitchInstr; /** * Components From a29402b4f3ddf5c194e372111885eb9b6eea3e06 Mon Sep 17 00:00:00 2001 From: AprupKale Date: Tue, 8 Apr 2025 11:00:15 +0800 Subject: [PATCH 29/29] Modify the CSEC machine to be able to run the break statement. Fix switch statements fallthrough logic and evaluation in the CSEC machine --- src/ec-evaluator/instrCreator.ts | 10 +++---- src/ec-evaluator/interpreter.ts | 49 ++++++++++++++++++++++---------- src/ec-evaluator/types.ts | 6 ++-- 3 files changed, 42 insertions(+), 23 deletions(-) diff --git a/src/ec-evaluator/instrCreator.ts b/src/ec-evaluator/instrCreator.ts index ef7d4ba9..a972b4bf 100644 --- a/src/ec-evaluator/instrCreator.ts +++ b/src/ec-evaluator/instrCreator.ts @@ -3,8 +3,8 @@ import { Expression, SwitchCase } from '../ast/types/blocks-and-statements' import { EnvNode } from "./components"; import { AssmtInstr, - BinOpInstr, - Class, CondInstr, + BinOpInstr, BranchInstr, + Class, DerefInstr, EnvInstr, EvalVarInstr, @@ -155,12 +155,12 @@ export const resConOverloadInstr = ( arity, }); -export const condInstr = ( +export const branchInstr = ( trueExpr: Expression, falseExpr: Expression, srcNode: Node, -): CondInstr => ({ - instrType: InstrType.COND, +): BranchInstr => ({ + instrType: InstrType.BRANCH, trueExpr, falseExpr, srcNode diff --git a/src/ec-evaluator/interpreter.ts b/src/ec-evaluator/interpreter.ts index 72090508..cea3f6ab 100644 --- a/src/ec-evaluator/interpreter.ts +++ b/src/ec-evaluator/interpreter.ts @@ -3,7 +3,7 @@ import { cloneDeep } from "lodash"; import { Assignment, BinaryExpression, - Block, + Block, BreakStatement, ClassInstanceCreationExpression, ExplicitConstructorInvocation, Expression, @@ -58,7 +58,7 @@ import { ResConOverloadInstr, ResOverrideInstr, ResTypeContInstr, - StructType, CondInstr, SwitchInstr + StructType, BranchInstr, SwitchInstr } from './types' import { defaultValues, @@ -419,7 +419,7 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { control: Control, _stash: Stash, ) => { - control.push(instr.condInstr(command.consequent, command.alternate, command)); + control.push(instr.branchInstr(command.consequent, command.alternate, command)); control.push(command.condition); }, @@ -429,10 +429,23 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { control: Control, _stash: Stash, ) => { + control.push(instr.markerInstr(command)); control.push(instr.switchInstr(command.cases, command.expression, command)); control.push(command.expression); }, + BreakStatement: ( + _command: BreakStatement, + _environment: Environment, + control: Control, + _stash: Stash, + ) => { + while ((control.peek() as Instr).instrType != InstrType.MARKER) { + control.pop(); + } + + control.pop(); // pop the marker + }, [InstrType.POP]: ( _command: Instr, @@ -863,8 +876,8 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { // No post-processing required for constructor. }, - [InstrType.COND]: ( - command: CondInstr, + [InstrType.BRANCH]: ( + command: BranchInstr, _environment: Environment, control: Control, stash: Stash, @@ -895,9 +908,11 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { const discValue = stash.pop() as Literal; let matchedCase: SwitchCase | null = null; + let matchedIndex = -1; // Iterate over each switch case. - for (const swCase of command.cases) { + for (let i = 0; i < command.cases.length; i++) { + const swCase = command.cases[i]; // Check all labels for this case. for (const label of swCase.labels) { if (label.kind === "CaseLabel") { @@ -905,24 +920,28 @@ const cmdEvaluators: { [type: string]: CmdEvaluator } = { const caseLiteral = label.expression as Literal; if (discValue.literalType.value === caseLiteral.literalType.value) { matchedCase = swCase; + matchedIndex = i; break; } - } else if (label.kind === "DefaultLabel") { + } else if (label.kind === "DefaultLabel" && !matchedCase) { // Save default case (only one default should exist). matchedCase = swCase; + matchedIndex = i; } } - if (matchedCase) break; } - // Determine which case to use. - if (matchedCase) { - if (matchedCase && matchedCase.statements && matchedCase.statements.length > 0) { + if (!matchedCase) { + return // do nothing if no matching case found. + } + + for (let i = command.cases.length; i >= matchedIndex; i--) { + const swCase = command.cases[i]; + + if (swCase && swCase.statements && swCase.statements.length > 0) { // Push the statements in reverse order to the control stack. - for (let i = matchedCase.statements.length - 1; i >= 0; i--) { - if (matchedCase.statements[i].kind == "BreakStatement") - continue - control.push(matchedCase.statements[i]); + for (let j = swCase.statements.length - 1; j >= 0; j--) { + control.push(swCase.statements[j]); } } } diff --git a/src/ec-evaluator/types.ts b/src/ec-evaluator/types.ts index 357e83ab..29d48c92 100644 --- a/src/ec-evaluator/types.ts +++ b/src/ec-evaluator/types.ts @@ -40,7 +40,7 @@ export enum InstrType { RES_OVERLOAD = 'ResOverload', RES_OVERRIDE = 'ResOverride', RES_CON_OVERLOAD = 'ResConOverload', - COND = 'Cond', + BRANCH = 'Branch', SWITCH = 'Switch', } @@ -100,7 +100,7 @@ export interface ResInstr extends BaseInstr { name: string; } -export interface CondInstr extends BaseInstr { +export interface BranchInstr extends BaseInstr { trueExpr: Expression; falseExpr: Expression; } @@ -128,7 +128,7 @@ export type Instr = | ResTypeContInstr | ResOverloadInstr | ResConOverloadInstr - | CondInstr + | BranchInstr | SwitchInstr; /**