User:Kmcguire/JavaClassLoadAndExecute

From OSDev.wiki
Revision as of 23:53, 7 May 2012 by Pancakes (talk | contribs) (unused fields in javaReadClass)
Jump to navigation Jump to search

Have you wanted to base your operating system on Java execution, but were never quite sure how to get started? Well, here is some code (better here than sitting on my hard disk) that can give you an idea of how to begin.

The code is by no means complete, but you can use it as a reference to guide you through some potentially confusing documentation. There is no type checking for method calls. I used a simple trick to determine how many arguments are being passed to a method from the stack: walk backwards down the stack until you find an object reference, then reverse the list of values you collected.

The most interesting part is probably the actual loading of the class file, which you can find in the code below.

It can execute some basic operations so far. There is no java.lang! I like to use Python to model out the entire program which lets me handle some of the pitfalls and traps. It gives me a good picture of what I need to do in C/C++.

Here are some resources which will allow you to implement a full virtual machine:

http://murrayc.com/learning/java/java_classfileformat.shtml
http://en.wikipedia.org/wiki/Java_bytecode_instruction_listings
http://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html#jvms-4.7.3

As I said, the code below is a guide that shows which parts are important to implement first. That helps keep you going, because you can run your test files while you build the rest. For instance, most of the constant pool types are not going to be used, but that is hard to tell from the specification alone.

Also, most instructions are not used. So you can work through them as you go: have the interpreter raise an error when it finds an unknown opcode, and implement that opcode when it comes up.

#!/usr/bin/python3.1
import os
import sys
import struct
import pprint

def funpack(fd, fmt):
	'''
		Read one struct-packed record from a binary stream.

		fd  -- object whose read(n) returns bytes
		fmt -- struct format string, for example '>HH'

		Requests exactly the number of bytes the format needs and
		returns the decoded values as a tuple. struct.error is raised
		when the stream delivers fewer bytes than that.
	'''
	layout = struct.Struct(fmt)
	return layout.unpack_from(fd.read(layout.size))

# Fixed-size constant pool entries: tag -> (struct format, dict keys).
# The key names for tags 7, 9, 10 and 12 are the ones the interpreter
# reads, so they are kept even where they are misnomers: in a methodref
# (10) 'name_index' holds the class index and 'descriptor_index' holds
# the name-and-type index. Tag 11 mirrors tag 10.
_CONST_LAYOUT = {
	3: ('>i', ('value',)),								# integer
	4: ('>f', ('value',)),								# float
	5: ('>q', ('value',)),								# long
	6: ('>d', ('value',)),								# double
	7: ('>H', ('name_index',)),							# classinfo
	8: ('>H', ('string_index',)),						# string
	9: ('>HH', ('classIndex', 'nameAndTypeIndex')),		# fieldref
	10: ('>HH', ('name_index', 'descriptor_index')),	# method ref
	11: ('>HH', ('name_index', 'descriptor_index')),	# interface method ref
	12: ('>HH', ('name_index', 'descriptor_index')),	# nameandtype
	15: ('>BH', ('reference_kind', 'reference_index')),	# method handle
	16: ('>H', ('descriptor_index',)),					# method type
	18: ('>HH', ('bootstrap_method_attr_index', 'name_and_type_index')),	# invokedynamic
}

# long and double constants occupy two consecutive constant pool slots
_CONST_WIDE_TAGS = (5, 6)


def _javaReadAttrs(fd, constPool, count):
	'''
		Read `count` attribute_info structures from fd.

		Returns a list of dicts with the keys 'name_index', 'length',
		'name' (utf8 bytes from the constant pool) and 'info' (the raw,
		unparsed attribute payload).
	'''
	attrs = []
	for _ in range(count):
		attr = {}
		attr['name_index'], attr['length'] = funpack(fd, '>HI')
		attr['name'] = constPool[attr['name_index']]['value']
		attr['info'] = fd.read(attr['length'])
		attrs.append(attr)
	return attrs


def javaReadClass(fd):
	'''
		Parse a Java .class file from the binary stream fd.

		Returns a dict with these keys:
			'vermin', 'vermaj' -- class file version numbers
			'constPool' -- dict: constant pool index -> entry dict; each
			               entry has a 'type' key holding its tag
			'name'      -- this class' name (bytes)
			'super'     -- superclass name (bytes), None if there is none
			'ifaces'    -- dict: 0-based position -> classinfo entry
			'fields'    -- dict: 1-based position -> field dict
			'methods'   -- dict: method name (bytes) -> method dict
			'attrs'     -- list of class level attribute dicts

		Raises Exception when the magic number is wrong, when the file
		ends inside the constant pool, or on an unknown constant pool tag.
	'''
	hdr = {}
	magic, vermin, vermaj = funpack(fd, '>IHH')
	print('magic:%x vermin:%x vermaj:%x' % (magic, vermin, vermaj))
	if magic != 0xCAFEBABE:
		# everything after this point would be parsed as garbage
		raise Exception('not a class file (magic:%x)' % magic)
	constPoolCnt = funpack(fd, '>H')[0]
	# =========== const pool =========
	constPool = {}
	hdr['vermin'] = vermin
	hdr['vermaj'] = vermaj
	hdr['constPool'] = constPool
	# valid indices run from 1 to constPoolCnt - 1
	x = 1
	while x < constPoolCnt:
		tag = fd.read(1)
		if not tag:
			raise Exception('unexpected end of file in constant pool')
		tag = tag[0]
		if tag == 1:
			# utf8: the only variable length entry
			sz = funpack(fd, '>H')[0]
			constPool[x] = {'type': 1, 'value': fd.read(sz)}
		elif tag in _CONST_LAYOUT:
			fmt, keys = _CONST_LAYOUT[tag]
			e = dict(zip(keys, funpack(fd, fmt)))
			e['type'] = tag
			constPool[x] = e
		else:
			raise Exception('unknown tag %s' % tag)
		# the slot after a long/double is unusable and must be skipped
		x = x + (2 if tag in _CONST_WIDE_TAGS else 1)
	# -------------------------
	classAccessFlags, thisClass, superClass, ifaceCnt = funpack(fd, '>HHHH')
	print('accessFlags:%x' % classAccessFlags)
	cinfo = constPool[thisClass]
	hdr['name'] = constPool[cinfo['name_index']]['value']
	if superClass == 0:
		# index 0 means "no superclass" (only java/lang/Object)
		hdr['super'] = None
	else:
		cinfo = constPool[superClass]
		hdr['super'] = constPool[cinfo['name_index']]['value']
	# ========= interfaces =========
	ifaces = {}
	hdr['ifaces'] = ifaces
	for x in range(ifaceCnt):
		iface = funpack(fd, '>H')[0]
		ifaces[x] = constPool[iface]
	# -------------------------------
	# ========= fields ==============
	fieldsCnt = funpack(fd, '>H')[0]
	fields = {}
	hdr['fields'] = fields
	for x in range(fieldsCnt):
		field = {}
		field['accessFlags'], field['name_index'], \
			field['descriptor_index'], field['attrCnt'] = \
				funpack(fd, '>HHHH')
		field['name'] = constPool[field['name_index']]['value']
		# the attributes must be consumed even though nothing uses them,
		# otherwise every later read is misaligned
		field['attrs'] = _javaReadAttrs(fd, constPool, field['attrCnt'])
		fields[x + 1] = field
	# -------------------------------
	# ========= methods ========
	methCnt = funpack(fd, '>H')[0]
	methods = {}
	hdr['methods'] = methods
	for _ in range(methCnt):
		method = {}
		method['access_flags'], method['name_index'], \
			method['descriptor_index'], method['attrCnt'] = \
				funpack(fd, '>HHHH')
		method['attrs'] = _javaReadAttrs(fd, constPool, method['attrCnt'])
		method['name'] = constPool[method['name_index']]['value']
		method['descriptor'] = constPool[method['descriptor_index']]['value']
		# keyed by name only: of several overloads the last one wins
		methods[method['name']] = method
	# -------------------------------
	# ========== attributes =========
	attrCnt = funpack(fd, '>H')[0]
	hdr['attrs'] = _javaReadAttrs(fd, constPool, attrCnt)
	return hdr

def javaGetClassMethodAndCode(jclass, methName):
	'''
		Find a method of a loaded class together with its Code attribute.

		jclass   -- class dict; its 'methods' entry maps method names to
		            method dicts and its 'name' entry is used in the
		            error message
		methName -- method name to look up, e.g. b'<init>'

		Returns (method, info) where info is the unparsed 'info' payload
		of the method's attribute named b'Code'.

		Raises Exception when the class has no method of that name or
		the method has no Code attribute.
	'''
	method = jclass['methods'].get(methName)
	if method is not None:
		# look for code attribute
		for attr in method['attrs']:
			if attr['name'] == b'Code':
				return (method, attr['info'])
	raise Exception('not found %s::%s' % (jclass['name'], methName))

def mkhexstr(data):
	'''
		Format every integer in data as lowercase hex, zero padded to
		at least two digits and followed by one space, and return the
		pieces joined into a single string.
	'''
	return ''.join('%02x ' % octet for octet in data)

# Type tags stored in Var.btype, numbered 0 through 7 in this order.
(
	TYPE_UNK,
	TYPE_OBJ,
	TYPE_INT,
	TYPE_SHORT,
	TYPE_LONG,
	TYPE_NULL,
	TYPE_FLOAT,
	TYPE_DOUBLE,
) = range(8)

# Printable label for each type tag; Var.__repr__ looks labels up here.
dbgmap = {
	TYPE_UNK: 'UNKNOWN',
	TYPE_OBJ: 'OBJECT',
	TYPE_INT: 'INT',
	TYPE_SHORT: 'SHORT',
	TYPE_LONG: 'LONG',
	TYPE_NULL: 'NULL',
	TYPE_FLOAT: 'FLOAT',
	TYPE_DOUBLE: 'DOUBLE',
}

class Obj:
	'''
		Runtime instance of a loaded Java class. Every attribute
		starts out as None and is assigned by javaObjectInstance.
	'''
	# per-instance field storage
	fields = None
	# method table taken from the class
	methods = None
	# class dict this object is an instance of
	jclass = None
	# class name; javaExecuteClassMethod prints it, so it is declared
	# here with the other attributes instead of only being attached
	# after construction
	name = None

class Var:
	'''
		Variant structure with easier printing support for debugging.

		btype is one of the TYPE_* tags and value is the payload. For
		TYPE_OBJ the payload is expected to expose a `fields` attribute,
		which is what gets printed.
	'''
	btype = 0
	value = None

	def __repr__(self):
		if self.btype != TYPE_OBJ:
			return '<%s:%s>' % (dbgmap[self.btype], self.value)
		return '<OBJECT fields=%s>' % self.value.fields

'''
	The stack contains variant like objects. We store
	basic type information on each stack item. 
'''
def var_mk_unk():
	'''Return a new Var tagged TYPE_UNK; its value is left unset.'''
	unknown = Var()
	unknown.btype = TYPE_UNK
	return unknown
def var_mk_obj(obj):
	'''Return a new Var tagged TYPE_OBJ that carries obj as its value.'''
	ref = Var()
	ref.btype, ref.value = TYPE_OBJ, obj
	return ref
def var_mk_float(f):
	'''Return a new Var tagged TYPE_FLOAT that carries f as its value.'''
	num = Var()
	num.btype, num.value = TYPE_FLOAT, f
	return num
def var_mk_null():
	'''Return a new Var tagged TYPE_NULL; its value is left unset.'''
	null = Var()
	null.btype = TYPE_NULL
	return null
def var_mk_long(l):
	'''Return a new Var tagged TYPE_LONG that carries l as its value.'''
	num = Var()
	num.btype, num.value = TYPE_LONG, l
	return num
def var_mk_int(i):
	'''Return a new Var tagged TYPE_INT that carries i as its value.'''
	num = Var()
	num.btype, num.value = TYPE_INT, i
	return num

def javaObjectInstance(jsys, jclassname):
	'''
		Create an instance of a bundled class and run its constructor.

		(1) find specified class in bundle via textual classname
		(2) create object structure
		(3) create fields in object
		(4) create methods in object (**needed??**)
		(5) push object reference to local_0
		(6) execute <init> method of object
		(7) add object to system instanced objects list
		(8) return the java object

		jsys       -- system object providing `bundle` (class name ->
		              class dict) and the `objects` list
		jclassname -- key of the class in jsys.bundle

		<init> is run with the new object as its only local, so only
		constructors without arguments are supported here. A KeyError
		is raised when the class is not in the bundle.
	'''
	# look through bundle and find jclassname
	jclass = jsys.bundle[jclassname]

	jobj = Obj()
	# One placeholder per declared field. The key is the field name
	# because that is the one identifier which stays the same no matter
	# which class' constant pool refers to the field.
	jobj.fields = {
		jclass['fields'][k]['name']: var_mk_unk()
		for k in jclass['fields']
	}
	jobj.methods = dict(jclass['methods'])
	jobj.jclass = jclass
	jobj.name = jclassname
	# need to execute initialization method
	# <init>
	local = {0: var_mk_obj(jobj)}
	javaExecuteClassMethod(jsys, jobj, jclass, b'<init>', local)

	# make system ref obj created
	jsys.objects.append(jobj)
	return jobj

# iconst_m1 .. iconst_5: opcode -> int constant pushed
_JAVA_ICONST = {0x02: -1, 0x03: 0, 0x04: 1, 0x05: 2, 0x06: 3, 0x07: 4, 0x08: 5}
# lconst_0, lconst_1: opcode -> long constant pushed
_JAVA_LCONST = {0x09: 0, 0x0a: 1}
# fconst_0 .. fconst_2: opcode -> float constant pushed
_JAVA_FCONST = {0x0b: 0.0, 0x0c: 1.0, 0x0d: 2.0}
# integer arithmetic: opcode -> result for operands a (pushed first)
# and b (pushed last, so popped first)
_JAVA_IARITH = {
	0x60: lambda a, b: a + b,		# iadd
	0x64: lambda a, b: a - b,		# isub
	0x68: lambda a, b: a * b,		# imul
	0x6c: lambda a, b: int(a / b),	# idiv: truncates toward zero
}


def _javaFieldName(constPool, ndx):
	'''Return the name (bytes) of the field that fieldref `ndx` refers to.'''
	fref = constPool[ndx]
	nat = constPool[fref['nameAndTypeIndex']]
	return constPool[nat['name_index']]['value']


def _javaPopCallFrame(stack):
	'''
		Pop the arguments and the receiver of a method call.

		There is no descriptor parsing: values are popped until the
		first object reference, which is taken to be the receiver. An
		object passed as an argument is therefore mistaken for the
		receiver.

		Returns (objref, frame); frame holds the callee's locals with
		the receiver in slot 0 and the arguments, in push order, from
		slot 1 on.
	'''
	args = []
	while stack and stack[-1].btype != TYPE_OBJ:
		args.append(stack.pop())
	if not stack:
		raise Exception('method call without an object reference on the stack')
	objref = stack.pop()
	frame = {0: objref}
	# the arguments were popped last-first, so undo that
	for slot, arg in enumerate(reversed(args), 1):
		frame[slot] = arg
	return objref, frame


def _javaAllocObject(jsys, jclassname):
	'''
		Create an object of a bundled class WITHOUT running <init>, and
		add it to jsys.objects. Fields are keyed by field name.
	'''
	jclass = jsys.bundle[jclassname]
	jobj = Obj()
	jobj.fields = {
		jclass['fields'][k]['name']: var_mk_unk()
		for k in jclass['fields']
	}
	jobj.methods = dict(jclass['methods'])
	jobj.jclass = jclass
	jobj.name = jclassname
	jsys.objects.append(jobj)
	return jobj


def javaExecuteClassMethod(jsys, jobj, jclass, methName, local):
	'''
		Interpret one method of a loaded class.

		(1) get the Code attribute for the method in the javaClass
		(2) cut the bytecode out of it
		(3) execute each opcode

		jsys     -- system object providing `bundle` and `objects`
		jobj     -- object the method runs on; only its name is used
		            here, for the trace output
		jclass   -- class dict that holds the method and the constant
		            pool its bytecode refers to
		methName -- method name (bytes)
		local    -- dict: local variable slot -> Var; slot 0 is the
		            receiver. Store instructions write into this dict.

		Returns the Var popped by a value return, a TYPE_NULL Var for a
		void return, 0 when putfield meets a non-object, and None when
		execution runs past the end of the bytecode.

		Raises Exception for opcodes that are not implemented.

		Limitations: no exceptions, no branches, no static calls, and
		no lookup of inherited methods. Longs and doubles take a single
		stack entry and a single local slot.
	'''
	method, info = javaGetClassMethodAndCode(jclass, methName)
	print(method)

	# Code attribute payload: u2 max_stack, u2 max_locals, u4 code_length,
	# then code_length bytes of bytecode. The exception table and nested
	# attributes that follow must not be executed.
	codeLen = struct.unpack_from('>I', info, 4)[0]
	code = info[8:8 + codeLen]

	print('\033[31mjavaExecuteClassMethod(class:%s method:%s\033[37m' % (jobj.name, methName))
	print(mkhexstr(code))
	constPool = jclass['constPool']
	# execution stack for method
	stack = []

	x = 0
	csz = len(code)
	while x < csz:
		opcode = code[x]
		print('===%02x===' % opcode)
		print(stack)
		# nop
		if opcode == 0x00:
			x = x + 1
			continue
		# aconst_null: push null onto the stack
		if opcode == 0x01:
			stack.append(var_mk_null())
			x = x + 1
			continue
		# iconst_m1 .. iconst_5
		if opcode in _JAVA_ICONST:
			stack.append(var_mk_int(_JAVA_ICONST[opcode]))
			x = x + 1
			continue
		# lconst_0, lconst_1
		if opcode in _JAVA_LCONST:
			stack.append(var_mk_long(_JAVA_LCONST[opcode]))
			x = x + 1
			continue
		# fconst_0 .. fconst_2
		if opcode in _JAVA_FCONST:
			stack.append(var_mk_float(_JAVA_FCONST[opcode]))
			x = x + 1
			continue
		# bipush: push a byte onto the stack as
		#         an integer value
		if opcode == 0x10:
			byte = code[x+1]
			# the operand is a signed byte
			if byte > 0x7f:
				byte = byte - 0x100
			stack.append(var_mk_int(byte))
			x = x + 2
			continue
		# iload, lload, fload, dload, aload: slot index follows
		if 0x15 <= opcode <= 0x19:
			stack.append(local[code[x+1]])
			x = x + 2
			continue
		# iload_0 .. aload_3: the slot number is part of the opcode,
		# four opcodes per type
		if 0x1a <= opcode <= 0x2d:
			stack.append(local[(opcode - 0x1a) % 4])
			x = x + 1
			continue
		# istore, lstore, fstore, dstore, astore: slot index follows
		if 0x36 <= opcode <= 0x3a:
			local[code[x+1]] = stack.pop()
			x = x + 2
			continue
		# istore_0 .. astore_3
		if 0x3b <= opcode <= 0x4e:
			local[(opcode - 0x3b) % 4] = stack.pop()
			x = x + 1
			continue
		# dup: duplicate the value on top stack
		if opcode == 0x59:
			stack.append(stack[-1])
			x = x + 1
			continue
		# iadd, isub, imul, idiv
		if opcode in _JAVA_IARITH:
			value1 = stack.pop()
			value2 = stack.pop()
			result = _JAVA_IARITH[opcode](value2.value, value1.value)
			stack.append(var_mk_int(result))
			x = x + 1
			continue
		# ireturn, lreturn, freturn, dreturn, areturn
		if 0xac <= opcode <= 0xb0:
			return stack.pop()
		# return void from method
		if opcode == 0xb1:
			return var_mk_null()
		# putstatic: set static field to value in the class
		if opcode == 0xb3:
			frefndx = code[x+1] << 8 | code[x+2]
			print('0xb3', frefndx)
			raise Exception('not implemented')
		# getfield
		if opcode == 0xb4:
			ndx = code[x+1] << 8 | code[x+2]
			objref = stack.pop()
			# Fields are stored under their name: the constant pool
			# index differs from class to class for the same field.
			stack.append(objref.value.fields[_javaFieldName(constPool, ndx)])
			x = x + 3
			continue
		# putfield: set field to value in object
		if opcode == 0xb5:
			ndx = code[x+1] << 8 | code[x+2]
			value = stack.pop()
			objref = stack.pop()
			print('objref:%s' % objref)
			if objref.btype != TYPE_OBJ:
				# best effort: give up on the whole method
				print('attempt to set field on non-object')
				return 0
			objref.value.fields[_javaFieldName(constPool, ndx)] = value
			print('putfield:%x\n' % ndx)
			x = x + 3
			continue
		# invokevirtual, invokespecial
		if opcode == 0xb6 or opcode == 0xb7:
			ndx = code[x+1] << 8 | code[x+2]
			m = constPool[ndx]
			# In a methodref entry 'name_index' leads to the class
			# info and 'descriptor_index' to the name-and-type entry.
			cinfo = constPool[m['name_index']]
			nat = constPool[m['descriptor_index']]
			print('..name:%s desc:%s' % (cinfo, nat))
			_name = constPool[cinfo['name_index']]['value']
			name = constPool[nat['name_index']]['value']
			desc = constPool[nat['descriptor_index']]['value']
			print('_name:%s name:%s descriptor:%s' % (_name, name, desc))

			# The callee gets its own locals; this method's `local`
			# must survive the call untouched.
			objref, frame = _javaPopCallFrame(stack)
			returnsValue = not desc.endswith(b')V')

			if opcode == 0xb7 and _name == b'java/lang/Object':
				# just ignore built-in stuff for now until
				# i can try to get it actually created when
				# get a chance
				print('\033[32mattempted call for %s:%s' % (name, desc))
				if returnsValue:
					stack.append(var_mk_unk())
				x = x + 3
				continue

			if opcode == 0xb7:
				# invokespecial is not virtual: run the method of the
				# class named in the reference. Going through the
				# object's own class would make a super.<init>() call
				# recurse forever.
				if _name not in jsys.bundle:
					raise Exception('Not In Bundle! [not-implemented]')
				target = jsys.bundle[_name]
			else:
				# invokevirtual: dispatch on the object's actual class
				target = objref.value.jclass

			ret = javaExecuteClassMethod(jsys, objref.value, target, name, frame)
			print('~~~ ret from call ~~~')
			print('@@', ret)
			# void methods leave nothing on the caller's stack
			if returnsValue:
				stack.append(ret)
			x = x + 3
			continue
		# new
		if opcode == 0xbb:
			ndx = code[x+1] << 8 | code[x+2]
			m = constPool[ndx]
			m = constPool[m['name_index']]['value']
			# find apple in our bundle or try loading
			# it from disk in the relative directory
			if m not in jsys.bundle:
				raise Exception('Not In Bundle! [not-implemented]')
			# Allocate only. The constructor is run by the
			# invokespecial <init> that follows `new` in the bytecode,
			# with the arguments the program actually passes.
			stack.append(var_mk_obj(_javaAllocObject(jsys, m)))
			x = x + 3
			continue

		print('stack:%s' % stack)
		raise Exception('opcode not understood %x' % opcode)

class JavaSystem:
	'''
		State of one virtual machine.

		bundle  -- dict: class name -> loaded class dict
		objects -- list of every object instantiated so far
	'''
	def __init__(self):
		# created per instance: as class attributes the dict and the
		# list would be shared by every JavaSystem
		self.bundle = {}
		self.objects = []

jsys = JavaSystem()

# Load the demo classes into the bundle, keyed by class name. `with`
# closes each file even when parsing fails. After the loop `jclass` is
# the class read from Test.class, which is the one executed below.
for path in ('Apple.class', 'Test.class'):
	with open(path, 'rb') as fd:
		jclass = javaReadClass(fd)
	jsys.bundle[jclass['name']] = jclass

jobj = javaObjectInstance(jsys, b'Test')

# run main with the new object in local slot 0
local = {}
local[0] = var_mk_obj(jobj)
ret = javaExecuteClassMethod(jsys, jobj, jclass, b'main', local)
print('------ return -------')
pprint.pprint(ret)