diff --git a/typed_python/SerializationContext.py b/typed_python/SerializationContext.py
index b110be90..832bbb41 100644
--- a/typed_python/SerializationContext.py
+++ b/typed_python/SerializationContext.py
@@ -31,11 +31,31 @@
 import types
 import traceback
 import logging
+import numpy
+import pickle
 
 _badModuleCache = set()
 
 
+def pickledByStr(module_name: str, name: str) -> object:
+    """Resolve the object referred to by module_name and name.
+
+    This mimics pickle's behavior when __reduce__ returns a string. The
+    string is interpreted as the name of a global variable, and pickle.whichmodule
+    is used to search for the module that defines it, producing module_name.
+
+    Note that 'name' might contain '.' inside of it, since it's a 'local name'.
+    """
+    module = importlib.import_module(module_name)
+
+    instance = module
+    for subName in name.split('.'):
+        instance = getattr(instance, subName)
+
+    return instance
+
+
 def createFunctionWithLocalsAndGlobals(code, globals):
     if globals is None:
         globals = {}
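For orientation, a minimal, self-contained sketch of the pickle protocol this helper mirrors; the Handle class and SINGLETON name are invented for illustration. When __reduce__ returns a bare string, pickle treats it as the name of a module-level global, and unpickling hands back the identical object rather than a copy:

    import pickle

    class Handle:
        def __reduce__(self):
            # returning a str tells pickle: "re-import me by name as a global"
            return 'SINGLETON'

    SINGLETON = Handle()

    # pickle.whichmodule locates the defining module ('__main__' here), and
    # unpickling resolves the name back to the very same object.
    assert pickle.loads(pickle.dumps(SINGLETON)) is SINGLETON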
@@ -708,26 +728,30 @@ def walkCodeObject(code):
             return (createFunctionWithLocalsAndGlobals, args, representation)
 
         if not isinstance(inst, type) and hasattr(type(inst), '__reduce_ex__'):
-            res = inst.__reduce_ex__(4)
+            if isinstance(inst, numpy.ufunc):
+                res = inst.__name__
+            else:
+                res = inst.__reduce_ex__(4)
 
-            # pickle supports a protocol where __reduce__ can return a string
-            # giving a global name. We'll already find that separately, so we
-            # don't want to handle it here. We ought to look at this in more detail
-            # however
+            # mimic pickle's behaviour when a string is received.
             if isinstance(res, str):
-                return None
+                name_tuple = (inst, res)
+                module_name = pickle.whichmodule(*name_tuple)
+                res = (pickledByStr, (module_name, res), pickledByStr)
 
             return res
 
         if not isinstance(inst, type) and hasattr(type(inst), '__reduce__'):
-            res = inst.__reduce__()
+            if isinstance(inst, numpy.ufunc):
+                res = inst.__name__
+            else:
+                res = inst.__reduce__()
 
-            # pickle supports a protocol where __reduce__ can return a string
-            # giving a global name. We'll already find that separately, so we
-            # don't want to handle it here. We ought to look at this in more detail
-            # however
+            # mimic pickle's behaviour when a string is received.
             if isinstance(res, str):
-                return None
+                name_tuple = (inst, res)
+                module_name = pickle.whichmodule(*name_tuple)
+                res = (pickledByStr, (module_name, res), pickledByStr)
 
             return res
 
@@ -736,6 +760,9 @@ def walkCodeObject(code):
     def setInstanceStateFromRepresentation(
         self, instance, representation=None, itemIt=None, kvPairIt=None, setStateFun=None
    ):
+        if representation is pickledByStr:
+            return
+
         if representation is reconstructTypeFunctionType:
             return
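To see how the new reduce triple round-trips, here is a hedged sketch using numpy.add; pickledByStr is reproduced inline so the snippet stands alone, and the module name is computed with pickle.whichmodule exactly as the serializer above does:

    import importlib
    import pickle
    import numpy

    def pickledByStr(module_name, name):
        obj = importlib.import_module(module_name)
        for part in name.split('.'):
            obj = getattr(obj, part)
        return obj

    # what the serializer would record for a ufunc
    module_name = pickle.whichmodule(numpy.add, 'add')
    representation = (pickledByStr, (module_name, 'add'), pickledByStr)

    factory, args, state_marker = representation
    assert factory(*args) is numpy.add
    # the third element doubles as a sentinel: setInstanceStateFromRepresentation
    # sees pickledByStr and skips the state-setting pass entirely.
    assert state_marker is pickledByStr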
diff --git a/typed_python/compiler/binary_shared_object.py b/typed_python/compiler/binary_shared_object.py
index 90089215..5c2e5765 100644
--- a/typed_python/compiler/binary_shared_object.py
+++ b/typed_python/compiler/binary_shared_object.py
@@ -26,8 +26,8 @@
 
 
 class LoadedBinarySharedObject(LoadedModule):
-    def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariableDefinitions):
-        super().__init__(functionPointers, globalVariableDefinitions)
+    def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlobalVariableDefinitions):
+        super().__init__(functionPointers, serializedGlobalVariableDefinitions)
 
         self.binarySharedObject = binarySharedObject
         self.diskPath = diskPath
@@ -36,15 +36,17 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariabl
 
 class BinarySharedObject:
     """Models a shared object library (.so) loadable on linux systems."""
 
-    def __init__(self, binaryForm, functionTypes, globalVariableDefinitions):
+    def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies):
         """
         Args:
-            binaryForm - a bytes object containing the actual compiled code for the module
-            globalVariableDefinitions - a map from name to GlobalVariableDefinition
+            binaryForm: a bytes object containing the actual compiled code for the module
+            serializedGlobalVariableDefinitions: a map from name to serialized GlobalVariableDefinition
+            globalDependencies: a dict from function linkname to the list of global variables it depends on
         """
         self.binaryForm = binaryForm
         self.functionTypes = functionTypes
-        self.globalVariableDefinitions = globalVariableDefinitions
+        self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions
+        self.globalDependencies = globalDependencies
         self.hash = sha_hash(binaryForm)
 
     @property
@@ -52,14 +54,14 @@ def definedSymbols(self):
         return self.functionTypes.keys()
 
     @staticmethod
-    def fromDisk(path, globalVariableDefinitions, functionNameToType):
+    def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
         with open(path, "rb") as f:
             binaryForm = f.read()
 
-        return BinarySharedObject(binaryForm, functionNameToType, globalVariableDefinitions)
+        return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
 
     @staticmethod
-    def fromModule(module, globalVariableDefinitions, functionNameToType):
+    def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
         target_triple = llvm.get_process_triple()
         target = llvm.Target.from_triple(target_triple)
         target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
@@ -80,7 +82,7 @@ def fromModule(module, globalVariableDefinitions, functionNameToType):
             )
 
             with open(os.path.join(tf, "module.so"), "rb") as so_file:
-                return BinarySharedObject(so_file.read(), functionNameToType, globalVariableDefinitions)
+                return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
 
     def load(self, storageDir):
         """Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer"""
@@ -127,8 +129,7 @@ def loadFromPath(self, modulePath):
             self,
             modulePath,
             functionPointers,
-            self.globalVariableDefinitions
+            self.serializedGlobalVariableDefinitions
         )
 
-        loadedModule.linkGlobalVariables()
         return loadedModule
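A consequence of the two changes above: loading a cached .so no longer links any globals (the linkGlobalVariables() call was dropped from loadFromPath), and the new globalDependencies map is what lets callers link selectively later. A hedged sketch of the consumer-side flow, with invented names and paths:

    # serializedDefs: dict of llvm global name -> serialized GlobalVariableDefinition
    bso = BinarySharedObject.fromDisk(
        "module.so",                                  # hypothetical path
        serializedDefs,
        nativeTypes,
        globalDependencies={"tp.f.abc123": ["global_str_0"]},
    )
    loaded = bso.loadFromPath("module.so")            # nothing linked yet

    wanted = bso.globalDependencies["tp.f.abc123"]
    toLink = {name: serializedDefs[name] for name in wanted}
    loaded.linkGlobalVariables(toLink)                # link just what f needs
    assert loaded.validateGlobalVariables(toLink)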
+ """ + return any(link_name in self.link_name_to_module_hash for link_name in self.func_name_to_link_names.get(func_name, [])) - nameToTypedCallTarget = {} - nameToNativeFunctionType = {} + def getTarget(self, func_name: str) -> TypedCallTarget: + if not self.hasSymbol(func_name): + raise ValueError(f'symbol not found for func_name {func_name}') + link_name = self._select_link_name(func_name) + self.loadForSymbol(link_name) + return self.targetsLoaded[link_name] - if not self.loadModuleByHash(moduleHash, nameToTypedCallTarget, nameToNativeFunctionType): - return None + def _generate_link_name(self, func_name: str, module_hash: str) -> str: + return func_name + "." + module_hash - return nameToTypedCallTarget, nameToNativeFunctionType + def _select_link_name(self, func_name) -> str: + """choose a link name for a given func name. - def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFunctionType): + Currently we just choose the first available option. + Throws a KeyError if func_name isn't in the cache. + """ + link_name_candidates = self.func_name_to_link_names[func_name] + return link_name_candidates[0] + + def dependencies(self, link_name: str) -> Optional[List[str]]: + """Returns all the function names that `link_name` depends on""" + return list(self.function_dependency_graph.outgoing(link_name)) + + def loadForSymbol(self, linkName: str) -> None: + """Loads the whole module, and any dependant modules, into LoadedBinarySharedObjects""" + moduleHash = self.link_name_to_module_hash[linkName] + + self.loadModuleByHash(moduleHash) + + if linkName not in self.targetsValidated: + self.targetsValidated.add(linkName) + for dependant_func in self.dependencies(linkName): + self.loadForSymbol(dependant_func) + + globalsToLink = self.global_dependencies.get(linkName, []) + if globalsToLink: + definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x] + for x in globalsToLink + } + self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink) + if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink): + raise RuntimeError('failed to validate globals when loading:', linkName) + + def loadModuleByHash(self, moduleHash: str) -> None: """Load a module by name. - As we load, place all the newly imported typed call targets into - 'nameToTypedCallTarget' so that the rest of the system knows what functions - have been uncovered. + Add the module contents to targetsLoaded, generate a LoadedBinarySharedObject, + and update the function and global dependency graphs. 
""" - if moduleHash in self.loadedModules: - return True + if moduleHash in self.loadedBinarySharedObjects: + return targetDir = os.path.join(self.cacheDir, moduleHash) - try: - with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f: - callTargets = SerializationContext().deserialize(f.read()) + # TODO (Will) - store these names as module consts, use one .dat only + with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f: + # func_name -> typedcalltarget + callTargets = SerializationContext().deserialize(f.read()) + + with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f: + serializedGlobalVarDefs = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f: - globalVarDefs = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f: + functionNameToNativeType = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f: - functionNameToNativeType = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: + submodules = SerializationContext().deserialize(f.read(), ListOf(str)) - with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: - submodules = SerializationContext().deserialize(f.read(), ListOf(str)) - except Exception: - self.markModuleHashInvalid(moduleHash) - return False + with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f: + dependency_edgelist = SerializationContext().deserialize(f.read()) - if not LoadedModule.validateGlobalVariables(globalVarDefs): - self.markModuleHashInvalid(moduleHash) - return False + with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f: + globalDependencies = SerializationContext().deserialize(f.read()) # load the submodules first for submodule in submodules: - if not self.loadModuleByHash( - submodule, - nameToTypedCallTarget, - nameToNativeFunctionType - ): - return False + self.loadModuleByHash(submodule) modulePath = os.path.join(targetDir, "module.so") loaded = BinarySharedObject.fromDisk( modulePath, - globalVarDefs, - functionNameToNativeType + serializedGlobalVarDefs, + functionNameToNativeType, + globalDependencies ).loadFromPath(modulePath) - self.loadedModules[moduleHash] = loaded + self.loadedBinarySharedObjects[moduleHash] = loaded - nameToTypedCallTarget.update(callTargets) - nameToNativeFunctionType.update(functionNameToNativeType) + for func_name, callTarget in callTargets.items(): + link_name = self._generate_link_name(func_name, moduleHash) + assert link_name not in self.targetsLoaded + self.targetsLoaded[link_name] = callTarget - return True + link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()} - def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies): + assert not any(key in self.global_dependencies for key in link_name_global_dependencies) + + self.global_dependencies.update(link_name_global_dependencies) + # update the cache's dependency graph with our new edges. + for function_name, dependant_function_name in dependency_edgelist: + self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name) + + def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist): """Add new code to the compiler cache. 
 
-    def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies):
+    def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist):
         """Add new code to the compiler cache.
 
         Args:
-            binarySharedObject - a BinarySharedObject containing the actual assembler
-                we've compiled
-            nameToTypedCallTarget - a dict from linkname to TypedCallTarget telling us
-                the formal python types for all the objects
-            linkDependencies - a set of linknames we depend on directly.
+            binarySharedObject: a BinarySharedObject containing the actual assembler
+                we've compiled.
+            nameToTypedCallTarget: a dict from func_name to TypedCallTarget telling us
+                the formal python types for all the objects.
+            linkDependencies: a set of func_names we depend on directly. (this becomes submodules)
+            dependencyEdgelist (list): a list of (source, dest) pairs giving the dependency
+                graph for the module.
+
+        TODO (Will): the notion of submodules/linkDependencies can be refactored out.
         """
-        dependentHashes = set()
+        hashToUse = SerializationContext().sha_hash(str(uuid.uuid4())).hexdigest
 
+        # the linkDependencies and dependencyEdgelist are in terms of func_name.
+        dependentHashes = set()
         for name in linkDependencies:
-            dependentHashes.add(self.nameToModuleHash[name])
+            link_name = self._select_link_name(name)
+            dependentHashes.add(self.link_name_to_module_hash[link_name])
+
+        link_name_dependency_edgelist = []
+        for source, dest in dependencyEdgelist:
+            assert source in binarySharedObject.definedSymbols
+            source_link_name = self._generate_link_name(source, hashToUse)
+            if dest in binarySharedObject.definedSymbols:
+                dest_link_name = self._generate_link_name(dest, hashToUse)
+            else:
+                dest_link_name = self._select_link_name(dest)
+            link_name_dependency_edgelist.append([source_link_name, dest_link_name])
 
-        path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes)
+        path = self.writeModuleToDisk(binarySharedObject, hashToUse, nameToTypedCallTarget, dependentHashes, link_name_dependency_edgelist)
 
-        self.loadedModules[hashToUse] = (
+        self.loadedBinarySharedObjects[hashToUse] = (
             binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
         )
 
-        for n in binarySharedObject.definedSymbols:
-            self.nameToModuleHash[n] = hashToUse
+        for func_name in binarySharedObject.definedSymbols:
+            link_name = self._generate_link_name(func_name, hashToUse)
+            self.link_name_to_module_hash[link_name] = hashToUse
+            self.func_name_to_link_names.setdefault(func_name, []).append(link_name)
 
+        # link & validate all globals for the new module
+        self.loadedBinarySharedObjects[hashToUse].linkGlobalVariables()
+        if not self.loadedBinarySharedObjects[hashToUse].validateGlobalVariables(
+                self.loadedBinarySharedObjects[hashToUse].serializedGlobalVariableDefinitions):
+            raise RuntimeError('failed to validate globals in new module:', hashToUse)
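A worked illustration of the edge rewriting above, with invented hashes: callees defined in the new module get its fresh hash appended, while external callees reuse whichever cached realization _select_link_name returns:

    hashToUse = "cccc2222cccc2222cccc2222cccc2222cccc2222"   # invented
    definedSymbols = {"tp.h.4567", "tp.g.89ab"}              # in the new module
    dependencyEdgelist = [
        ["tp.h.4567", "tp.g.89ab"],    # internal: both ends get hashToUse
        ["tp.h.4567", "tp.f.0123"],    # external: f lives in an older module
    ]
    # after rewriting:
    #   tp.h.4567.cccc... -> tp.g.89ab.cccc...
    #   tp.h.4567.cccc... -> tp.f.0123.<hash of the module f was cached in>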
-    def loadNameManifestFromStoredModuleByHash(self, moduleHash):
-        if moduleHash in self.modulesMarkedValid:
-            return True
-
-        targetDir = os.path.join(self.cacheDir, moduleHash)
-
-        # ignore 'marked invalid'
-        if os.path.exists(os.path.join(targetDir, "marked_invalid")):
-            # just bail - don't try to read it now
-
-            # for the moment, we don't try to clean up the cache, because
-            # we can't be sure that some process is not still reading the
-            # old files.
-            self.modulesMarkedInvalid.add(moduleHash)
-            return False
+    def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:
+        if moduleHash in self.moduleManifestsLoaded:
+            return
 
-        with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
-            submodules = SerializationContext().deserialize(f.read(), ListOf(str))
-
-        for subHash in submodules:
-            if not self.loadNameManifestFromStoredModuleByHash(subHash):
-                self.markModuleHashInvalid(subHash)
-                return False
+        targetDir = os.path.join(self.cacheDir, moduleHash)
 
+        # TODO (Will) the name_manifest module_hash is the same throughout so this doesn't need to be a dict.
         with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f:
-            self.nameToModuleHash.update(
-                SerializationContext().deserialize(f.read(), Dict(str, str))
-            )
+            func_name_to_module_hash = SerializationContext().deserialize(f.read(), Dict(str, str))
 
-        self.modulesMarkedValid.add(moduleHash)
+        for func_name, module_hash in func_name_to_module_hash.items():
+            link_name = self._generate_link_name(func_name, module_hash)
+            self.func_name_to_link_names.setdefault(func_name, []).append(link_name)
+            self.link_name_to_module_hash[link_name] = module_hash
 
-        return True
+        self.moduleManifestsLoaded.add(moduleHash)
 
-    def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules):
+    def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget, submodules, dependencyEdgelist):
         """Write out a disk representation of this module.
 
         This includes writing both the shared object, a manifest of the function names
@@ -207,7 +268,6 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
         to interact with the compiler cache simultaneously without relying on
         individual file-level locking.
""" - hashToUse = SerializationContext().sha_hash(str(uuid.uuid4())).hexdigest targetDir = os.path.join( self.cacheDir, @@ -236,21 +296,24 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule for sourceName in manifest: f.write(sourceName + "\n") - # write the type manifest with open(os.path.join(tempTargetDir, "type_manifest.dat"), "wb") as f: f.write(SerializationContext().serialize(nameToTypedCallTarget)) - # write the nativetype manifest with open(os.path.join(tempTargetDir, "native_type_manifest.dat"), "wb") as f: f.write(SerializationContext().serialize(binarySharedObject.functionTypes)) - # write the type manifest with open(os.path.join(tempTargetDir, "globals_manifest.dat"), "wb") as f: - f.write(SerializationContext().serialize(binarySharedObject.globalVariableDefinitions)) + f.write(SerializationContext().serialize(binarySharedObject.serializedGlobalVariableDefinitions)) with open(os.path.join(tempTargetDir, "submodules.dat"), "wb") as f: f.write(SerializationContext().serialize(ListOf(str)(submodules), ListOf(str))) + with open(os.path.join(tempTargetDir, "function_dependencies.dat"), "wb") as f: + f.write(SerializationContext().serialize(dependencyEdgelist)) + + with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.globalDependencies)) + try: os.rename(tempTargetDir, targetDir) except IOError: @@ -259,14 +322,15 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule else: shutil.rmtree(tempTargetDir) - return targetDir, hashToUse + return targetDir - def function_pointer_by_name(self, linkName): - moduleHash = self.nameToModuleHash.get(linkName) + def function_pointer_by_name(self, func_name): + linkName = self._select_link_name(func_name) + moduleHash = self.link_name_to_module_hash.get(linkName) if moduleHash is None: raise Exception("Can't find a module for " + linkName) - if moduleHash not in self.loadedModules: + if moduleHash not in self.loadedBinarySharedObjects: self.loadForSymbol(linkName) - return self.loadedModules[moduleHash].functionPointers[linkName] + return self.loadedBinarySharedObjects[moduleHash].functionPointers[func_name] diff --git a/typed_python/compiler/compiler_cache_test.py b/typed_python/compiler/compiler_cache_test.py index 81ad2f12..639e929e 100644 --- a/typed_python/compiler/compiler_cache_test.py +++ b/typed_python/compiler/compiler_cache_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
diff --git a/typed_python/compiler/compiler_cache_test.py b/typed_python/compiler/compiler_cache_test.py
index 81ad2f12..639e929e 100644
--- a/typed_python/compiler/compiler_cache_test.py
+++ b/typed_python/compiler/compiler_cache_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import tempfile
+import threading
 import os
 import pytest
 from typed_python.test_util import evaluateExprInFreshProcess
@@ -24,6 +25,47 @@ def f(x):
 """
 
 
+class MockDirectory:
+    def __init__(self, tree):
+        self.tree = tree
+
+    @staticmethod
+    def fromPath(path):
+        tree = {}
+
+        def populate(path, tree):
+            for relPath in os.listdir(path):
+                subPath = os.path.join(path, relPath)
+
+                if os.path.isdir(subPath):
+                    populate(subPath, tree.setdefault(relPath, {}))
+                elif os.path.isfile(subPath):
+                    with open(subPath, "rb") as file:
+                        tree[relPath] = file.read()
+
+        populate(path, tree)
+
+        return MockDirectory(tree)
+
+    def dumpInto(self, path):
+        def populate(path, tree):
+            if isinstance(tree, dict):
+                if not os.path.isdir(path):
+                    os.mkdir(path)
+
+                for relPath, subtree in tree.items():
+                    subpath = os.path.join(path, relPath)
+
+                    populate(subpath, subtree)
+            elif isinstance(tree, bytes):
+                with open(path, "wb") as file:
+                    file.write(tree)
+            else:
+                raise Exception(f"Can't handle {type(tree)} in the tree.")
+
+        populate(path, self.tree)
+
+
 @pytest.mark.skipif('sys.platform=="darwin"')
 def test_compiler_cache_populates():
     with tempfile.TemporaryDirectory() as compilerCacheDir:
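MockDirectory exists to snapshot one process's cache directory as a dict tree and replay it into another directory, which is how the race-condition tests below merge two independently built caches. A usage sketch with placeholder contents:

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as dirA, tempfile.TemporaryDirectory() as dirB:
        os.mkdir(os.path.join(dirA, "modulehash1"))       # pretend process A cached this
        with open(os.path.join(dirA, "modulehash1", "module.so"), "wb") as f:
            f.write(b"\x7fELF...")                        # placeholder bytes

        snapshot = MockDirectory.fromPath(dirA)           # capture A's cache
        snapshot.dumpInto(dirB)                           # overlay onto B's cache

        assert os.path.exists(os.path.join(dirB, "modulehash1", "module.so"))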
@@ -119,6 +161,7 @@ def test_compiler_cache_understands_type_changes():
     VERSION1 = {'x.py': xmodule, 'y.py': ymodule}
     VERSION2 = {'x.py': xmodule.replace("1: 2", "1: 3"), 'y.py': ymodule}
     VERSION3 = {'x.py': xmodule.replace("int, int", "int, float").replace('1: 2', '1: 2.5'), 'y.py': ymodule}
+    VERSION4 = {'x.py': xmodule.replace("1: 2", "1: 4"), 'y.py': ymodule}
 
     assert '1: 3' in VERSION2['x.py']
 
@@ -134,6 +177,10 @@ def test_compiler_cache_understands_type_changes():
         assert evaluateExprInFreshProcess(VERSION3, 'y.g(1)', compilerCacheDir) == 2.5
         assert len(os.listdir(compilerCacheDir)) == 2
 
+        # use the previously compiled module
+        assert evaluateExprInFreshProcess(VERSION4, 'y.g(1)', compilerCacheDir) == 4
+        assert len(os.listdir(compilerCacheDir)) == 2
+
 
 @pytest.mark.skipif('sys.platform=="darwin"')
 def test_compiler_cache_handles_exceptions_properly():
@@ -256,6 +303,164 @@ def test_reference_existing_function_twice():
         assert len(os.listdir(compilerCacheDir)) == 2
 
 
+@pytest.mark.skipif('sys.platform=="darwin"')
+def test_can_compile_overlapping_code():
+    common = "\n".join([
+        "import time",
+        "import os",
+
+        "t0 = time.time()",
+        "path = os.path.join(os.getenv('TP_COMPILER_CACHE'), 'check.txt')",
+        "while not os.path.exists(path):",
+        "    time.sleep(.01)",
+        "    assert time.time() - t0 < 2",
+    ])
+
+    xmodule1 = "\n".join([
+        "import common",
+        "@Entrypoint",
+        "def f():",
+        "    x = Dict(int, int)()",
+        "    x[3] = 4",
+        "    return x[3]"
+    ])
+
+    xmodule2 = "\n".join([
+        "import common",
+        "@Entrypoint",
+        "def g():",
+        "    x = Dict(int, int)()",
+        "    x[3] = 5",
+        "    return x[3]"
+    ])
+
+    xmodule3 = "\n".join([
+        "from x1 import f",
+        "from x2 import g",
+        "@Entrypoint",
+        "def h():",
+        "    return f() + g()"
+    ])
+
+    MODULES = {'common.py': common, 'x1.py': xmodule1, 'x2.py': xmodule2, 'x3.py': xmodule3}
+
+    with tempfile.TemporaryDirectory() as compilerCacheDir:
+        # first, compile 'f' and 'g' in two separate processes. Because of the loop
+        # they will wait until we write out the file that lets them start compiling. Then
+        # they'll both compile something that has common code.
+        # we should be able to then use that common code without issue.
+        threads = [
+            threading.Thread(
+                target=evaluateExprInFreshProcess, args=(MODULES, 'x1.f()', compilerCacheDir)
+            ),
+            threading.Thread(
+                target=evaluateExprInFreshProcess, args=(MODULES, 'x2.g()', compilerCacheDir)
+            ),
+        ]
+        for t in threads:
+            t.start()
+
+        with open(os.path.join(compilerCacheDir, "check.txt"), 'w') as f:
+            f.write('start!')
+
+        for t in threads:
+            t.join()
+
+        assert evaluateExprInFreshProcess(MODULES, 'x3.h()', compilerCacheDir) == 9
+
+
+RACE_CONDITION_MODULES = {
+    'x1.py': "\n".join([
+        "@Entrypoint",
+        "def f():",
+        "    x = Dict(int, int)()",
+        "    x[3] = 4",
+        "    return x[3]"]),
+    'x2.py': "\n".join([
+        "import x1",
+        "@Entrypoint",
+        "def g():",
+        "    x = Dict(int, int)()",
+        "    x[3] = 5",
+        "    return x[3], x1.f()"
+    ]),
+    'x3.py': "\n".join([
+        "from x1 import f",
+        "from x2 import g",
+        "@Entrypoint",
+        "def h():",
+        "    return f() + g()[0]"
+    ])}
+
+
+@pytest.mark.skipif('sys.platform=="darwin"')
+def test_compiler_cache_race_condition_triangle():
+    """Given double-compilation of f, and g depends on f, test we can load the cached f and compile g."""
+    # compile a module containing f
+    with tempfile.TemporaryDirectory() as dir1:
+        evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x1.f()", dir1)
+        cache_dir = MockDirectory.fromPath(dir1)
+
+    # recompile the f-module in a new cache.
+    with tempfile.TemporaryDirectory() as dir2:
+        evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x1.f()", dir2)
+        # merge the two caches to simulate a race condition
+        cache_dir.dumpInto(dir2)
+        assert len(os.listdir(dir2)) == 2
+        # compile a module containing g that depends on f. Run f and g in fresh processes.
+        assert evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x2.g()", dir2) == (5, 4)
+        assert evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x1.f()", dir2) == 4
+
+
+@pytest.mark.skipif('sys.platform=="darwin"')
+def test_compiler_cache_race_condition_diamond():
+    """Given single-compilation of f, and double-compilation of g, and h depends on g, test we
+    can load the cached f and g and compile h.
+    """
+    # one module with f
+    with tempfile.TemporaryDirectory() as dir1:
+        evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x1.f()", dir1)
+        cache_dir = MockDirectory.fromPath(dir1)
+    # two gs, both depending on f
+    with tempfile.TemporaryDirectory() as dir2:
+        cache_dir.dumpInto(dir2)
+        assert len(os.listdir(dir2)) == 1
+        evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x2.g()", dir2)
+        cache_dir_2 = MockDirectory.fromPath(dir2)
+    with tempfile.TemporaryDirectory() as dir3:
+        cache_dir.dumpInto(dir3)
+        assert len(os.listdir(dir3)) == 1
+        evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x2.g()", dir3)
+        cache_dir_3 = MockDirectory.fromPath(dir3)
+
+    with tempfile.TemporaryDirectory() as dir4:
+        # we should now have three modules, one with f and two with g.
+        cache_dir_3.dumpInto(dir4)
+        cache_dir_2.dumpInto(dir4)
+        assert len(os.listdir(dir4)) == 3
+        assert evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x3.h()", dir4) == 9
+        assert evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x2.g()", dir4) == (5, 4)
+        assert evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x1.f()", dir4) == 4
+ """ + # two modules both with f and g + with tempfile.TemporaryDirectory() as dir1: + evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x2.g()", dir1) + cache_dir = MockDirectory.fromPath(dir1) + module = os.listdir(dir1)[0] + with open(os.path.join(dir1, module, "name_manifest.txt"), "r") as flines: + manifest = flines.read() + assert "tp.g" in manifest and "tp.f" in manifest + + with tempfile.TemporaryDirectory() as dir2: + evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x2.g()", dir2) + cache_dir.dumpInto(dir2) + assert len(os.listdir(dir2)) == 2 + assert evaluateExprInFreshProcess(RACE_CONDITION_MODULES, "x3.h()", dir2) == 9 + + @pytest.mark.skipif('sys.platform=="darwin"') def test_compiler_cache_handles_class_destructors_correctly(): xmodule = "\n".join([ @@ -362,12 +567,9 @@ def test_compiler_cache_handles_changed_types(): assert evaluateExprInFreshProcess(VERSION2, 'x.f(1)', compilerCacheDir) == 1 assert len(os.listdir(compilerCacheDir)) == 2 - badCt = 0 - for subdir in os.listdir(compilerCacheDir): - if 'marked_invalid' in os.listdir(os.path.join(compilerCacheDir, subdir)): - badCt += 1 - - assert badCt == 1 + # if we then use g1 again, it should not have been marked invalid and so remains accessible. + assert evaluateExprInFreshProcess(VERSION1, 'x.g1(1)', compilerCacheDir) == 1 + assert len(os.listdir(compilerCacheDir)) == 2 @pytest.mark.skipif('sys.platform=="darwin"') @@ -395,3 +597,95 @@ def test_ordering_is_stable_under_code_change(): ) assert names == names2 + + +@pytest.mark.skipif('sys.platform=="darwin"') +def test_compiler_cache_avoids_deserialization_error(): + xmodule1 = "\n".join([ + "@Entrypoint", + "def f():", + " return None", + "import badModule", + "@Entrypoint", + "def g():", + " print(badModule)", + " return f()", + ]) + + xmodule2 = "\n".join([ + "@Entrypoint", + "def f():", + " return", + ]) + + VERSION1 = {'x.py': xmodule1, 'badModule.py': ''} + VERSION2 = {'x.py': xmodule2} + + with tempfile.TemporaryDirectory() as compilerCacheDir: + evaluateExprInFreshProcess(VERSION1, 'x.g()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 1 + evaluateExprInFreshProcess(VERSION2, 'x.f()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 2 + evaluateExprInFreshProcess(VERSION1, 'x.g()', compilerCacheDir) + assert len(os.listdir(compilerCacheDir)) == 2 + + +@pytest.mark.skipif('sys.platform=="darwin"') +def test_compiler_cache_can_handle_cyclic_dependency_graph(): + xmodule = """ + @Entrypoint + def getX(): + # class C: + def f(): + var = g(0) + return var + + def g(x): + if x == 0: + return 1 + else: + return f() + + return f() + """.replace('\n ', '\n') + MODULE = {'x.py': xmodule} + + with tempfile.TemporaryDirectory() as compilerCacheDir: + # run twice to check cached code can be retrieved + assert evaluateExprInFreshProcess(MODULE, 'x.getX()', compilerCacheDir) == 1 + assert evaluateExprInFreshProcess(MODULE, 'x.getX()', compilerCacheDir) == 1 + + +@pytest.mark.skipif('sys.platform=="darwin"') +def test_compiler_cache_throws_on_import_loop(): + """It is possible, when compiling a module, to attempt to deserialise + a callTarget containing a module import which runs an Entrypointed function. + This results in a 'compilation loop' where one iteration of the conversion + is waiting on another, which currently breaks our model of the compilation + process. 
+ """ + module1 = """ + @Entrypoint + def f(x): + return x+1 + f(1) + """.replace('\n ', '\n') + module2 = """ + @Entrypoint + def g(): + import x + """.replace('\n ', '\n') + module3 = """ + import y + def rung(): + try: + y.g() + except ImportError as e: + return 'ImportError caught' + rung() + """.replace('\n ', '\n') + with tempfile.TemporaryDirectory() as compilerCacheDir: + + evaluateExprInFreshProcess({'x.py': module1}, 'x.f(1)', compilerCacheDir) + exception_string = evaluateExprInFreshProcess({'z.py': module3, 'y.py': module2, 'x.py': module1, }, 'z.rung()', compilerCacheDir) + assert exception_string == 'ImportError caught' diff --git a/typed_python/compiler/global_variable_definition.py b/typed_python/compiler/global_variable_definition.py index 4f01c11f..4dbf34f8 100644 --- a/typed_python/compiler/global_variable_definition.py +++ b/typed_python/compiler/global_variable_definition.py @@ -79,3 +79,12 @@ def __init__(self, name, typ, metadata): self.name = name self.type = typ self.metadata = metadata + + def __eq__(self, other): + if not isinstance(other, GlobalVariableDefinition): + return False + + return self.name == other.name and self.type == other.type and self.metadata == other.metadata + + def __str__(self): + return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={pad(str(self.metadata))})" diff --git a/typed_python/compiler/llvm_compiler.py b/typed_python/compiler/llvm_compiler.py index 9579df16..f33e5edb 100644 --- a/typed_python/compiler/llvm_compiler.py +++ b/typed_python/compiler/llvm_compiler.py @@ -22,7 +22,7 @@ from typed_python.compiler.binary_shared_object import BinarySharedObject import ctypes -from typed_python import _types +from typed_python import _types, SerializationContext llvm.initialize() llvm.initialize_native_target() @@ -84,18 +84,14 @@ def create_execution_engine(inlineThreshold): class Compiler: - def __init__(self, inlineThreshold): + def __init__(self, inlineThreshold, compilerCache): self.engine, self.module_pass_manager = create_execution_engine(inlineThreshold) - self.converter = native_ast_to_llvm.Converter() + self.converter = native_ast_to_llvm.Converter(compilerCache) self.functions_by_name = {} self.inlineThreshold = inlineThreshold self.verbose = False self.optimize = True - def markExternal(self, functionNameToType): - """Provide type signatures for a set of external functions.""" - self.converter.markExternal(functionNameToType) - def mark_converter_verbose(self): self.converter.verbose = True @@ -121,17 +117,20 @@ def buildSharedObject(self, functions): self.engine.finalize_object() + serializedGlobalVariableDefinitions = {x: SerializationContext().serialize(y) for x, y in module.globalVariableDefinitions.items()} + return BinarySharedObject.fromModule( mod, - module.globalVariableDefinitions, + serializedGlobalVariableDefinitions, module.functionNameToType, + module.globalDependencies ) def function_pointer_by_name(self, name): return self.functions_by_name.get(name) def buildModule(self, functions): - """Compile a list of functions into a new module. + """Compile a list of functions into a new module. Only relevant if there is no compiler cache. 
 
         Args:
             functions - a map from name to native_ast.Function
@@ -187,4 +186,5 @@ def buildModule(self, functions):
             )
         )
 
-        return LoadedModule(native_function_pointers, module.globalVariableDefinitions)
+        serializedGlobalVariableDefinitions = {x: SerializationContext().serialize(y) for x, y in module.globalVariableDefinitions.items()}
+        return LoadedModule(native_function_pointers, serializedGlobalVariableDefinitions)
diff --git a/typed_python/compiler/llvm_compiler_test.py b/typed_python/compiler/llvm_compiler_test.py
index e10f9453..d914bae4 100644
--- a/typed_python/compiler/llvm_compiler_test.py
+++ b/typed_python/compiler/llvm_compiler_test.py
@@ -20,6 +20,8 @@
 from typed_python.compiler.module_definition import ModuleDefinition
 from typed_python.compiler.global_variable_definition import GlobalVariableMetadata
 
+from typed_python.test_util import evaluateExprInFreshProcess
+
 import pytest
 import ctypes
 
@@ -115,7 +117,7 @@ def test_create_binary_shared_object():
         {'__test_f_2': f}
     )
 
-    assert len(bso.globalVariableDefinitions) == 1
+    assert len(bso.serializedGlobalVariableDefinitions) == 1
 
     with tempfile.TemporaryDirectory() as tf:
         loaded = bso.load(tf)
@@ -131,3 +133,28 @@ def test_create_binary_shared_object():
         pointers[0].set(5)
 
     assert loaded.functionPointers['__test_f_2']() == 5
+
+
+@pytest.mark.skipif('sys.platform=="darwin"')
+def test_loaded_modules_persist():
+    """
+    Make sure that loaded modules are persisted in the converter state.
+
+    We have to maintain these references to avoid surprise segfaults - if this test fails,
+    it should be because the GlobalVariableDefinition memory management has been refactored.
+    """
+
+    # compile a module
+    xmodule = "\n".join([
+        "@Entrypoint",
+        "def f(x):",
+        "    return x + 1",
+        "@Entrypoint",
+        "def g(x):",
+        "    return f(x) * 100",
+        "g(1000)",
+        "def get_loaded_modules():",
+        "    return len(Runtime.singleton().converter.loadedUncachedModules)"
+    ])
+    VERSION1 = {'x.py': xmodule}
+
+    assert evaluateExprInFreshProcess(VERSION1, 'x.get_loaded_modules()') == 1
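Both buildSharedObject and buildModule now serialize every GlobalVariableDefinition eagerly, so nothing handed to the cache holds live Python objects. A hedged sketch of the round-trip; 'module' stands for the ModuleDefinition built above and the global name is invented:

    from typed_python import SerializationContext

    serialized = {
        name: SerializationContext().serialize(gvd)
        for name, gvd in module.globalVariableDefinitions.items()
    }

    # consumers (e.g. linkGlobalVariables below) deserialize one at a time,
    # touching only the definitions a requested function actually needs.
    gvd = SerializationContext().deserialize(serialized["global_str_0"])
    assert gvd.metadata is not None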
diff --git a/typed_python/compiler/loaded_module.py b/typed_python/compiler/loaded_module.py
index c03ab321..ffb2112c 100644
--- a/typed_python/compiler/loaded_module.py
+++ b/typed_python/compiler/loaded_module.py
@@ -1,39 +1,48 @@
+from typing import Dict, List
 from typed_python.compiler.module_definition import ModuleDefinition
-from typed_python import PointerTo, ListOf, Class
+from typed_python import PointerTo, ListOf, Class, SerializationContext
 from typed_python import _types
 
 
 class LoadedModule:
     """Represents a bundle of compiled functions that are now loaded in memory.
 
     Members:
         functionPointers - a map from name to NativeFunctionPointer giving the
             public interface of the module
-        globalVariableDefinitions - a map from name to GlobalVariableDefinition
+        serializedGlobalVariableDefinitions - a map from LLVM-assigned global name to serialized GlobalVariableDefinition
             giving the loadable strings
     """
     GET_GLOBAL_VARIABLES_NAME = ModuleDefinition.GET_GLOBAL_VARIABLES_NAME
 
-    def __init__(self, functionPointers, globalVariableDefinitions):
+    def __init__(self, functionPointers, serializedGlobalVariableDefinitions):
         self.functionPointers = functionPointers
 
+        assert ModuleDefinition.GET_GLOBAL_VARIABLES_NAME in self.functionPointers
+
+        self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions
+        self.orderedDefs = [
+            self.serializedGlobalVariableDefinitions[name] for name in sorted(self.serializedGlobalVariableDefinitions)
+        ]
+        self.orderedDefNames = sorted(list(self.serializedGlobalVariableDefinitions.keys()))
+        self.pointers = ListOf(PointerTo(int))()
+        self.pointers.resize(len(self.orderedDefs))
 
-        self.globalVariableDefinitions = globalVariableDefinitions
+        self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](self.pointers.pointerUnsafe(0))
+
+        self.installedGlobalVariableDefinitions = {}
 
     @staticmethod
-    def validateGlobalVariables(globalVariableDefinitions):
+    def validateGlobalVariables(serializedGlobalVariableDefinitions: Dict[str, bytes]) -> bool:
         """Check that each global variable definition is sensible.
 
-
         Sometimes we may successfully deserialize a global variable from a cached
         module, but then some dictionary member is not valid because it was removed
         or has the wrong type. In this case, we need to evict this module from the
         cache because it's no longer valid.
 
         Args:
-            globalVariableDefinitions - a dict from string to GlobalVariableMetadata
+            serializedGlobalVariableDefinitions: a dict from string to a serialized GlobalVariableMetadata
         """
-        for gvd in globalVariableDefinitions.values():
-            meta = gvd.metadata
-
+        for gvd in serializedGlobalVariableDefinitions.values():
+            meta = SerializationContext().deserialize(gvd).metadata
             if meta.matches.PointerToTypedPythonObjectAsMemberOfDict:
                 if not isinstance(meta.sourceDict, dict):
                     return False
@@ -54,54 +63,47 @@ def validateGlobalVariables(globalVariableDefinitions):
 
         return True
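With the pointer table captured once in __init__, linking becomes per-slot work, which is what lets the compiler cache link only the globals a function needs. A hedged sketch of partial linking, with invented global names:

    # suppose serializedGlobalVariableDefinitions has two entries:
    #   {'global_int_1': b'...', 'global_str_0': b'...'}

    module.linkGlobalVariables(['global_str_0'])   # populate one slot only;
    # the slot for 'global_int_1' stays zeroed until something needs it.

    module.linkGlobalVariables()                   # or link every slot at once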
""" - assert ModuleDefinition.GET_GLOBAL_VARIABLES_NAME in self.functionPointers - - orderedDefs = [ - self.globalVariableDefinitions[name] for name in sorted(self.globalVariableDefinitions) - ] - pointers = ListOf(PointerTo(int))() - pointers.resize(len(orderedDefs)) + if variable_names is None: + i_vals = range(len(self.orderedDefs)) + else: + i_vals = [self.orderedDefNames.index(x) for x in variable_names] - self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](pointers.pointerUnsafe(0)) + for i in i_vals: + assert self.pointers[i], f"Failed to get a pointer to {self.orderedDefs[i].name}" - for i in range(len(orderedDefs)): - assert pointers[i], f"Failed to get a pointer to {orderedDefs[i].name}" + meta = SerializationContext().deserialize(self.orderedDefs[i]).metadata - meta = orderedDefs[i].metadata + self.installedGlobalVariableDefinitions[i] = meta if meta.matches.StringConstant: - pointers[i].cast(str).initialize(meta.value) + self.pointers[i].cast(str).initialize(meta.value) if meta.matches.IntegerConstant: - pointers[i].cast(int).initialize(meta.value) + self.pointers[i].cast(int).initialize(meta.value) elif meta.matches.BytesConstant: - pointers[i].cast(bytes).initialize(meta.value) + self.pointers[i].cast(bytes).initialize(meta.value) elif meta.matches.PointerToPyObject: - pointers[i].cast(object).initialize(meta.value) + self.pointers[i].cast(object).initialize(meta.value) elif meta.matches.PointerToTypedPythonObject: - pointers[i].cast(meta.type).initialize(meta.value) + self.pointers[i].cast(meta.type).initialize(meta.value) elif meta.matches.PointerToTypedPythonObjectAsMemberOfDict: - pointers[i].cast(meta.type).initialize(meta.sourceDict[meta.name]) + self.pointers[i].cast(meta.type).initialize(meta.sourceDict[meta.name]) elif meta.matches.ClassMethodDispatchSlot: slotIx = _types.allocateClassMethodDispatch( @@ -111,17 +113,17 @@ def linkGlobalVariables(self): meta.argTupleType, meta.kwargTupleType ) - pointers[i].cast(int).initialize(slotIx) + self.pointers[i].cast(int).initialize(slotIx) elif meta.matches.IdOfPyObject: - pointers[i].cast(int).initialize(id(meta.value)) + self.pointers[i].cast(int).initialize(id(meta.value)) elif meta.matches.ClassVtable: - pointers[i].cast(int).initialize( + self.pointers[i].cast(int).initialize( _types._vtablePointer(meta.value) ) elif meta.matches.RawTypePointer: - pointers[i].cast(int).initialize( + self.pointers[i].cast(int).initialize( _types.getTypePointer(meta.value) ) diff --git a/typed_python/compiler/module_definition.py b/typed_python/compiler/module_definition.py index 4e9b35fd..cadbb2ec 100644 --- a/typed_python/compiler/module_definition.py +++ b/typed_python/compiler/module_definition.py @@ -18,15 +18,19 @@ class ModuleDefinition: """A single module of compiled llvm code. - Members: - moduleText - a string containing the llvm IR for the module - functionList - a list of the names of exported functions - globalDefinitions - a dict from name to a GlobalDefinition + Attributes: + moduleText (str): a string containing the llvm IR for the module + functionList (list): a list of the names of exported functions + globalDefinitions (dict): a dict from name to a GlobalDefinition + globalDependencies (dict): a dict from function link_name to a list of globals the + function depends on + hash (str): The module hash, generated from the llvm IR. 
""" GET_GLOBAL_VARIABLES_NAME = ".get_global_variables" - def __init__(self, moduleText, functionNameToType, globalVariableDefinitions): + def __init__(self, moduleText, functionNameToType, globalVariableDefinitions, globalDependencies): self.moduleText = moduleText self.functionNameToType = functionNameToType self.globalVariableDefinitions = globalVariableDefinitions + self.globalDependencies = globalDependencies self.hash = sha_hash(moduleText) diff --git a/typed_python/compiler/native_ast_to_llvm.py b/typed_python/compiler/native_ast_to_llvm.py index 4850ed95..bbd2027d 100644 --- a/typed_python/compiler/native_ast_to_llvm.py +++ b/typed_python/compiler/native_ast_to_llvm.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typed_python.compiler.native_ast as native_ast -from typed_python.compiler.module_definition import ModuleDefinition -from typed_python.compiler.global_variable_definition import GlobalVariableDefinition import llvmlite.ir import os - +import typed_python.compiler.native_ast as native_ast +from typed_python.compiler.global_variable_definition import GlobalVariableDefinition +from typed_python.compiler.module_definition import ModuleDefinition +from typing import Dict llvm_i8ptr = llvmlite.ir.IntType(8).as_pointer() llvm_i8 = llvmlite.ir.IntType(8) llvm_i32 = llvmlite.ir.IntType(32) @@ -501,19 +501,19 @@ def __init__(self, module, globalDefinitions, globalDefinitionLlvmValues, - function, converter, builder, arg_assignments, output_type, - external_function_references + external_function_references, + compilerCache, ): - self.function = function # dict from name to GlobalVariableDefinition self.globalDefinitions = globalDefinitions self.globalDefinitionLlvmValues = globalDefinitionLlvmValues - + # a list of the global LLVM names that the function depends on. + self.global_names = [] self.module = module self.converter = converter self.builder = builder @@ -522,6 +522,7 @@ def __init__(self, self.external_function_references = external_function_references self.tags_initialized = {} self.stack_slots = {} + self.compilerCache = compilerCache def tags_as(self, new_tags): class scoper(): @@ -631,7 +632,16 @@ def generate_exception_and_store_value(self, llvm_pointer_val): ) return self.builder.bitcast(exception_ptr, llvm_i8ptr) - def namedCallTargetToLLVM(self, target): + def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVMValue: + """ + Generate llvm IR code for a given target. + + There are three options for code generation: + 1. The target is external, i.e something like pyobj_len, np_add_traceback - system-level functions. We add to + external_function_references. + 2. The function is in function_definitions, in which case we grab the function definition and make an inlining decision. + 3. We have a compiler cache, and the function is in it. We add to external_function_references. 
+ """ if target.external: if target.name not in self.external_function_references: func_type = llvmlite.ir.FunctionType( @@ -648,7 +658,23 @@ def namedCallTargetToLLVM(self, target): llvmlite.ir.Function(self.module, func_type, target.name) func = self.external_function_references[target.name] - elif target.name in self.converter._externallyDefinedFunctionTypes: + elif target.name in self.converter._function_definitions: + func = self.converter._functions_by_name[target.name] + if func.module is not self.module: + # first, see if we'd like to inline this module + if ( + self.converter.totalFunctionComplexity(target.name) < CROSS_MODULE_INLINE_COMPLEXITY + ): + func = self.converter.repeatFunctionInModule(target.name, self.module) + else: + if target.name not in self.external_function_references: + self.external_function_references[target.name] = \ + llvmlite.ir.Function(self.module, func.function_type, func.name) + + func = self.external_function_references[target.name] + else: + # TODO (Will): decide whether to inline cached code + assert self.compilerCache is not None and self.compilerCache.hasSymbol(target.name) # this function is defined in a shared object that we've loaded from a prior # invocation if target.name not in self.external_function_references: @@ -665,22 +691,6 @@ def namedCallTargetToLLVM(self, target): ) func = self.external_function_references[target.name] - else: - func = self.converter._functions_by_name[target.name] - - if func.module is not self.module: - # first, see if we'd like to inline this module - if ( - self.converter.totalFunctionComplexity(target.name) < CROSS_MODULE_INLINE_COMPLEXITY - and self.converter.canBeInlined(target.name) - ): - func = self.converter.repeatFunctionInModule(target.name, self.module) - else: - if target.name not in self.external_function_references: - self.external_function_references[target.name] = \ - llvmlite.ir.Function(self.module, func.function_type, func.name) - - func = self.external_function_references[target.name] return TypedLLVMValue( func, @@ -801,6 +811,7 @@ def _convert(self, expr): return self.stack_slots[expr.name] if expr.matches.GlobalVariable: + self.global_names.append(expr.name) if expr.name in self.globalDefinitions: assert expr.metadata == self.globalDefinitions[expr.name].metadata, ( expr.metadata, self.globalDefinitions[expr.name].metadata @@ -1484,15 +1495,11 @@ def define(fname, output, inputs, vararg=False): class Converter: - def __init__(self): + def __init__(self, compilerCache=None): object.__init__(self) self._modules = {} - self._functions_by_name = {} - self._function_definitions = {} - - # a map from function name to function type for functions that - # are defined in external shared objects and linked in to this one. 
@@ -1484,15 +1495,11 @@ def define(fname, output, inputs, vararg=False):
 
 class Converter:
-    def __init__(self):
+    def __init__(self, compilerCache=None):
         object.__init__(self)
         self._modules = {}
-        self._functions_by_name = {}
-        self._function_definitions = {}
-
-        # a map from function name to function type for functions that
-        # are defined in external shared objects and linked in to this one.
-        self._externallyDefinedFunctionTypes = {}
+        self._functions_by_name: Dict[str, llvmlite.ir.Function] = {}
+        self._function_definitions: Dict[str, native_ast.Function] = {}
 
         # total number of instructions in each function, by name
         self._function_complexity = {}
@@ -1502,17 +1509,12 @@ def __init__(self):
         self._printAllNativeCalls = os.getenv("TP_COMPILER_LOG_NATIVE_CALLS")
         self.verbose = False
 
-    def markExternal(self, functionNameToType):
-        """Provide type signatures for a set of external functions."""
-        self._externallyDefinedFunctionTypes.update(functionNameToType)
-
-    def canBeInlined(self, name):
-        return name not in self._externallyDefinedFunctionTypes
+        self.compilerCache = compilerCache
 
     def totalFunctionComplexity(self, name):
         """Return the total number of instructions contained in a function.
 
-        The function must already have been defined in a prior parss. We use this
+        The function must already have been defined in a prior pass. We use this
         information to decide which functions to repeat in new module definitions.
         """
         if name in self._function_complexity:
@@ -1546,9 +1548,7 @@ def repeatFunctionInModule(self, name, module):
         assert isinstance(funcType, llvmlite.ir.FunctionType)
 
         self._functions_by_name[name] = llvmlite.ir.Function(module, funcType, name)
-
         self._inlineRequests.append(name)
-
         return self._functions_by_name[name]
 
     def add_functions(self, names_to_definitions):
@@ -1604,7 +1604,8 @@ def add_functions(self, names_to_definitions):
 
         globalDefinitions = {}
         globalDefinitionsLlvmValues = {}
-
+        # we need a separate dictionary owing to the possibility of global var reuse across functions.
+        globalDependencies = {}
         while names_to_definitions:
             for name in sorted(names_to_definitions):
                 definition = names_to_definitions.pop(name)
@@ -1628,12 +1629,12 @@ def add_functions(self, names_to_definitions):
                     module,
                     globalDefinitions,
                     globalDefinitionsLlvmValues,
-                    func,
                     self,
                     builder,
                     arg_assignments,
                     definition.output_type,
-                    external_function_references
+                    external_function_references,
+                    self.compilerCache,
                 )
 
                 func_converter.setup()
@@ -1642,6 +1643,8 @@ def add_functions(self, names_to_definitions):
 
                 func_converter.finalize()
 
+                globalDependencies[func.name] = func_converter.global_names
+
                 if res is not None:
                     assert res.llvm_value is None
                     if definition.output_type != native_ast.Void:
@@ -1675,7 +1678,8 @@ def add_functions(self, names_to_definitions):
         return ModuleDefinition(
             str(module),
             functionTypes,
-            globalDefinitions
+            globalDefinitions,
+            globalDependencies
         )
 
     def defineGlobalMetadataAccessor(self, module, globalDefinitions, globalDefinitionsLlvmValues):
diff --git a/typed_python/compiler/python_to_native_converter.py b/typed_python/compiler/python_to_native_converter.py
index c9bb2748..f8a99f40 100644
--- a/typed_python/compiler/python_to_native_converter.py
+++ b/typed_python/compiler/python_to_native_converter.py
@@ -17,6 +17,7 @@
 from typed_python.hash import Hash
 from types import ModuleType
+from typing import Dict
 from typed_python import Class
 import typed_python.python_ast as python_ast
 import typed_python._types as _types
@@ -72,19 +73,21 @@ def getNextDirtyNode(self):
 
         return identity
 
-    def addRoot(self, identity):
+    def addRoot(self, identity, dirty=True):
         if identity not in self._identity_levels:
             self._identity_levels[identity] = 0
-            self.markDirty(identity)
+            if dirty:
+                self.markDirty(identity)
 
-    def addEdge(self, caller, callee):
+    def addEdge(self, caller, callee, dirty=True):
         if caller not in self._identity_levels:
             raise Exception(f"unknown identity {caller} found in the graph")
the graph") if callee not in self._identity_levels: self._identity_levels[callee] = self._identity_levels[caller] + 1 - self.markDirty(callee, isNew=True) + if dirty: + self.markDirty(callee, isNew=True) self._dependencies.addEdge(caller, callee) @@ -122,21 +125,21 @@ def __init__(self, llvmCompiler, compilerCache): self.llvmCompiler = llvmCompiler self.compilerCache = compilerCache + # all LoadedModule objects that we have created. We need to keep them alive so + # that any python metadata objects the've created stay alive as well. Ultimately, this + # may not be the place we put these objects (for instance, you could imagine a + # 'dummy' compiler cache or something). But for now, we need to keep them alive. + self.loadedUncachedModules = [] + # if True, then insert additional code to check for undefined behavior. self.generateDebugChecks = False - # all link names for which we have a definition. - self._allDefinedNames = set() - - # all names we loaded from the cache - self._allCachedNames = set() - self._link_name_for_identity = {} self._identity_for_link_name = {} - self._definitions = {} - self._targets = {} + self._definitions: Dict[str, native_ast.Function] = {} + self._targets: Dict[str, TypedCallTarget] = {} self._inflight_definitions = {} - self._inflight_function_conversions = {} + self._inflight_function_conversions: Dict[str, FunctionConversionContext] = {} self._identifier_to_pyfunc = {} self._times_calculated = {} @@ -186,33 +189,37 @@ def identityToName(self, identity): return self._link_name_for_identity.get(identity) def buildAndLinkNewModule(self): - targets = self.extract_new_function_definitions() + definitions = self.extract_new_function_definitions() - if not targets: + if not definitions: return if self.compilerCache is None: - loadedModule = self.llvmCompiler.buildModule(targets) + loadedModule = self.llvmCompiler.buildModule(definitions) loadedModule.linkGlobalVariables() + self.loadedUncachedModules.append(loadedModule) return # get a set of function names that we depend on externallyUsed = set() + dependency_edgelist = [] - for funcName in targets: + for funcName in definitions: ident = self._identity_for_link_name.get(funcName) if ident is not None: for dep in self._dependencies.getNamesDependedOn(ident): depLN = self._link_name_for_identity.get(dep) - if depLN not in targets: + dependency_edgelist.append([funcName, depLN]) + if depLN not in definitions: externallyUsed.add(depLN) - binary = self.llvmCompiler.buildSharedObject(targets) + binary = self.llvmCompiler.buildSharedObject(definitions) self.compilerCache.addModule( binary, - {name: self._targets[name] for name in targets if name in self._targets}, - externallyUsed + {name: self.getTarget(name) for name in definitions if self.hasTarget(name)}, + externallyUsed, + dependency_edgelist ) def extract_new_function_definitions(self): @@ -226,7 +233,7 @@ def extract_new_function_definitions(self): return res - def identityHashToLinkerName(self, name, identityHash, prefix="tp."): + def identityHashToFunctionName(self, name, identityHash, prefix="tp."): assert isinstance(name, str) assert isinstance(identityHash, str) assert isinstance(prefix, str) @@ -274,69 +281,70 @@ def defineLinkName(self, identity, linkName): self._link_name_for_identity[identity] = linkName self._identity_for_link_name[linkName] = identity - if linkName in self._allDefinedNames: - return False - - self._allDefinedNames.add(linkName) - - self._loadFromCompilerCache(linkName) + def hasTarget(self, linkName): + return 
self.getTarget(linkName) is not None - return True + def deleteTarget(self, linkName): + self._targets.pop(linkName) - def _loadFromCompilerCache(self, linkName): - if self.compilerCache: - if self.compilerCache.hasSymbol(linkName): - callTargetsAndTypes = self.compilerCache.loadForSymbol(linkName) + def setTarget(self, linkName, target): + assert(isinstance(target, TypedCallTarget)) + self._targets[linkName] = target - if callTargetsAndTypes is not None: - newTypedCallTargets, newNativeFunctionTypes = callTargetsAndTypes + def getTarget(self, linkName) -> TypedCallTarget: + if linkName in self._targets: + return self._targets[linkName] - self._targets.update(newTypedCallTargets) - self.llvmCompiler.markExternal(newNativeFunctionTypes) + if self.compilerCache is not None and self.compilerCache.hasSymbol(linkName): + return self.compilerCache.getTarget(linkName) - self._allDefinedNames.update(newNativeFunctionTypes) - self._allCachedNames.update(newNativeFunctionTypes) + return None def defineNonPythonFunction(self, name, identityTuple, context): """Define a non-python generating function (if we haven't defined it before already) name - the name to actually give the function. identityTuple - a unique (sha)hashable tuple - context - a FunctionConvertsionContext lookalike + context - a FunctionConversionContext lookalike returns a TypedCallTarget, or None if it's not known yet """ identity = self.hashObjectToIdentity(identityTuple).hexdigest - linkName = self.identityHashToLinkerName(name, identity, "runtime.") + linkName = self.identityHashToFunctionName(name, identity, "runtime.") self.defineLinkName(identity, linkName) + target = self.getTarget(linkName) + if self._currentlyConverting is not None: - self._dependencies.addEdge(self._currentlyConverting, identity) + self._dependencies.addEdge(self._currentlyConverting, identity, dirty=(target is None)) else: - self._dependencies.addRoot(identity) + self._dependencies.addRoot(identity, dirty=(target is None)) - if linkName in self._targets: - return self._targets.get(linkName) + if target is not None: + return target self._inflight_function_conversions[identity] = context if context.knownOutputType() is not None or context.alwaysRaises(): - self._targets[linkName] = self.getTypedCallTarget( - name, - context.getInputTypes(), - context.knownOutputType(), - alwaysRaises=context.alwaysRaises(), - functionMetadata=context.functionMetadata + self.setTarget( + linkName, + self.getTypedCallTarget( + name, + context.getInputTypes(), + context.knownOutputType(), + alwaysRaises=context.alwaysRaises(), + functionMetadata=context.functionMetadata, + ) ) if self._currentlyConverting is None: # force the function to resolve immediately self._resolveAllInflightFunctions() - self._installInflightFunctions(name) + self._installInflightFunctions() self._inflight_function_conversions.clear() - return self._targets.get(linkName) + return self.getTarget(linkName) def defineNativeFunction(self, name, identity, input_types, output_type, generatingFunction): """Define a native function if we haven't defined it before already. 
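The hasTarget / getTarget / setTarget / deleteTarget accessors above, together with the dirty=(target is None) flags threaded through addRoot and addEdge, replace the old _allDefinedNames / _allCachedNames bookkeeping with a single two-level lookup: targets compiled in this process first, then the shared compiler cache, with a miss being what schedules a fresh conversion. A runnable sketch of that protocol, assuming only the hasSymbol / getTarget cache interface the diff itself calls (FakeCompilerCache, TargetRegistry, and the link names are illustrative, not part of the codebase):

    from typing import Dict, Optional

    class FakeCompilerCache:
        """Stand-in for the on-disk compiler cache shared across processes."""
        def __init__(self):
            self._symbols: Dict[str, object] = {}

        def hasSymbol(self, linkName: str) -> bool:
            return linkName in self._symbols

        def getTarget(self, linkName: str) -> object:
            return self._symbols[linkName]

    class TargetRegistry:
        """Mimics the converter's target resolution order."""
        def __init__(self, compilerCache: Optional[FakeCompilerCache] = None):
            self._targets: Dict[str, object] = {}  # link name -> call target
            self.compilerCache = compilerCache

        def getTarget(self, linkName: str) -> Optional[object]:
            # 1) targets compiled in this process win
            if linkName in self._targets:
                return self._targets[linkName]
            # 2) otherwise fall back to artifacts cached by earlier processes
            if self.compilerCache is not None and self.compilerCache.hasSymbol(linkName):
                return self.compilerCache.getTarget(linkName)
            # 3) a miss: the caller marks the identity dirty and converts it
            return None

    cache = FakeCompilerCache()
    cache._symbols["tp.f.abc123"] = "cached-target"
    registry = TargetRegistry(cache)

    assert registry.getTarget("tp.f.abc123") == "cached-target"  # no recompile
    assert registry.getTarget("tp.g.def456") is None             # dirty: convert
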
@@ -459,13 +467,16 @@ def generateCallConverter(self, callTarget: TypedCallTarget):
        identifier = "call_converter_" + callTarget.name
        linkName = callTarget.name + ".dispatch"

-        if linkName in self._allDefinedNames:
+        # we already made a definition for this in this process, so don't do it again
+        if linkName in self._definitions:
            return linkName

-        self._loadFromCompilerCache(linkName)
-        if linkName in self._allDefinedNames:
+        # we already defined it in another process, so don't do it again
+        if self.compilerCache is not None and self.compilerCache.hasSymbol(linkName):
            return linkName

+        # N.B. there aren't targets for call converters. We make the definition directly.
+
        args = []
        for i in range(len(callTarget.input_types)):
            if not callTarget.input_types[i].is_empty:
@@ -503,7 +517,6 @@ def generateCallConverter(self, callTarget: TypedCallTarget):

        self._link_name_for_identity[identifier] = linkName
        self._identity_for_link_name[linkName] = identifier
-        self._allDefinedNames.add(linkName)

        self._definitions[linkName] = definition
        self._new_native_functions.add(linkName)
@@ -516,10 +529,6 @@ def _resolveAllInflightFunctions(self):
            if not identity:
                return

-            linkName = self._link_name_for_identity[identity]
-            if linkName in self._allCachedNames:
-                continue
-
            functionConverter = self._inflight_function_conversions[identity]

            hasDefinitionBeforeConversion = identity in self._inflight_definitions
@@ -529,6 +538,8 @@ def _resolveAllInflightFunctions(self):

            self._times_calculated[identity] = self._times_calculated.get(identity, 0) + 1

+            # this calls back into 'convert' for any dependencies,
+            # which get registered as dirty
            nativeFunction, actual_output_type = functionConverter.convertToNativeFunction()

            if nativeFunction is not None:
@@ -537,9 +548,8 @@ def _resolveAllInflightFunctions(self):
                for i in self._inflight_function_conversions:
                    if i in self._link_name_for_identity:
                        name = self._link_name_for_identity[i]
-                        if name in self._targets:
-                            self._targets.pop(name)
-                        self._allDefinedNames.discard(name)
+                        if self.hasTarget(name):
+                            self.deleteTarget(name)
                        ln = self._link_name_for_identity.pop(i)
                        self._identity_for_link_name.pop(ln)
@@ -567,12 +577,15 @@ def _resolveAllInflightFunctions(self):

                name = self._link_name_for_identity[identity]

-                self._targets[name] = self.getTypedCallTarget(
+                self.setTarget(
                    name,
-                    functionConverter._input_types,
-                    actual_output_type,
-                    alwaysRaises=functionConverter.alwaysRaises(),
-                    functionMetadata=functionConverter.functionMetadata
+                    self.getTypedCallTarget(
+                        name,
+                        functionConverter._input_types,
+                        actual_output_type,
+                        alwaysRaises=functionConverter.alwaysRaises(),
+                        functionMetadata=functionConverter.functionMetadata,
+                    ),
                )

            if dirtyUpstream:
@@ -670,22 +683,22 @@ def compileClassDestructor(self, cls):
            _types.installClassDestructor(cls, fp.fp)
            self._installedDestructors.add(cls)

-    def functionPointerByName(self, linkerName) -> NativeFunctionPointer:
+    def functionPointerByName(self, func_name) -> NativeFunctionPointer:
        """Find a NativeFunctionPointer for a given link-time name.

        Args:
-            linkerName (str) - the name of the compiled symbol we want
+            func_name (str) - the name of the compiled symbol we want.
Returns: a NativeFunctionPointer or None """ if self.compilerCache is None: # the llvm compiler holds it all - return self.llvmCompiler.function_pointer_by_name(linkerName) + return self.llvmCompiler.function_pointer_by_name(func_name) else: # the llvm compiler is just building shared objects, but the # compiler cache has all the pointers. - return self.compilerCache.function_pointer_by_name(linkerName) + return self.compilerCache.function_pointer_by_name(func_name) def convertTypedFunctionCall(self, functionType, overloadIx, inputWrappers, assertIsRoot=False): overload = functionType.overloads[overloadIx] @@ -846,7 +859,7 @@ def convert( identity = identityHash.hexdigest - name = self.identityHashToLinkerName(funcName, identity) + name = self.identityHashToFunctionName(funcName, identity) self.defineLinkName(identity, name) @@ -860,13 +873,15 @@ def convert( if assertIsRoot: assert isRoot + target = self.getTarget(name) + if self._currentlyConverting is not None: - self._dependencies.addEdge(self._currentlyConverting, identity) + self._dependencies.addEdge(self._currentlyConverting, identity, dirty=(target is None)) else: - self._dependencies.addRoot(identity) + self._dependencies.addRoot(identity, dirty=(target is None)) - if name in self._targets: - return self._targets[name] + if target is not None: + return target if identity not in self._inflight_function_conversions: functionConverter = self.createConversionContext( @@ -880,14 +895,13 @@ def convert( output_type, conversionType ) - self._inflight_function_conversions[identity] = functionConverter if isRoot: try: self._resolveAllInflightFunctions() - self._installInflightFunctions(name) - return self._targets[name] + self._installInflightFunctions() + return self.getTarget(name) finally: self._inflight_function_conversions.clear() @@ -897,12 +911,12 @@ def convert( # target with an output type and we can return that. Otherwise we have to # return None, which will cause callers to replace this with a throw # until we have had a chance to do a full pass of conversion. - if name in self._targets: - return self._targets[name] - else: - return None + if self.getTarget(name) is not None: + raise RuntimeError(f"Unexpected conversion error for {name}") + return None - def _installInflightFunctions(self, name): + def _installInflightFunctions(self): + """Add all function definitions corresponding to keys in inflight_function_conversions to the relevant dictionaries.""" if VALIDATE_FUNCTION_DEFINITIONS_STABLE: # this should always be true, but its expensive so we have it off by default for identifier, functionConverter in self._inflight_function_conversions.items(): @@ -919,7 +933,11 @@ def _installInflightFunctions(self, name): outboundTargets = [] for outboundFuncId in self._dependencies.getNamesDependedOn(identifier): name = self._link_name_for_identity[outboundFuncId] - outboundTargets.append(self._targets[name]) + target = self.getTarget(name) + if target is not None: + outboundTargets.append(target) + else: + raise RuntimeError(f'dependency not found for {name}.') nativeFunction, actual_output_type = self._inflight_definitions.get(identifier) diff --git a/typed_python/compiler/runtime.py b/typed_python/compiler/runtime.py index 8621147a..a339c2be 100644 --- a/typed_python/compiler/runtime.py +++ b/typed_python/compiler/runtime.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect
 import threading
 import os
 import time
@@ -207,7 +208,7 @@ def __init__(self):
             )
         else:
             self.compilerCache = None

-        self.llvm_compiler = llvm_compiler.Compiler(inlineThreshold=100)
+        self.llvm_compiler = llvm_compiler.Compiler(inlineThreshold=100, compilerCache=self.compilerCache)

         self.converter = python_to_native_converter.PythonToNativeConverter(
             self.llvm_compiler,
             self.compilerCache
@@ -501,6 +502,16 @@ def Entrypoint(pyFunc):
     if not callable(typedFunc):
         raise Exception(f"Can only compile functions, not {typedFunc}")

+    # check if we are already in the middle of the compilation process, due to the Entrypointed
+    # code being called through a module import, and raise an error if so.
+    if is_importing():
+        compiling_func = Runtime.singleton().converter._currentlyConverting
+        compiling_func_link_name = Runtime.singleton().converter._link_name_for_identity[compiling_func]
+        error_message = (
+            f"Can't import Entrypointed code {pyFunc.__module__}.{pyFunc.__qualname__} "
+            f"while {compiling_func_link_name} is being compiled."
+        )
+        raise ImportError(error_message)

     typedFunc = Function(typedFunc)

@@ -534,3 +543,20 @@ def Compiled(pyFunc):
         f.resultTypeFor(*types)

     return f
+
+
+def is_importing():
+    """Walk the stack to check if we are currently importing a module.
+
+    In this case, we will have an 'importlib' frame between two 'typed_python.compiler.runtime' frames.
+    """
+    in_runtime = False
+    assert __name__ == 'typed_python.compiler.runtime', 'is_importing() should only be called from typed_python.compiler.runtime'
+    for frame, *_ in inspect.stack()[::-1]:
+        frame_name = frame.f_globals.get("__name__")
+        if frame_name == 'typed_python.compiler.runtime':
+            in_runtime = True
+        if in_runtime and frame_name == 'importlib':
+            return True
+
+    return False
diff --git a/typed_python/compiler/tests/numpy_interaction_test.py b/typed_python/compiler/tests/numpy_interaction_test.py
index f15bfea9..db774c6d 100644
--- a/typed_python/compiler/tests/numpy_interaction_test.py
+++ b/typed_python/compiler/tests/numpy_interaction_test.py
@@ -1,4 +1,4 @@
-from typed_python import ListOf, Entrypoint
+from typed_python import ListOf, Entrypoint, SerializationContext

 import numpy
 import numpy.linalg
@@ -44,3 +44,12 @@ def test_listof_from_sliced_numpy_array():
     y = x[::2]

     assert ListOf(int)(y) == [0, 2]
+
+
+def test_can_serialize_numpy_ufunc():
+    assert numpy.sin == SerializationContext().deserialize(SerializationContext().serialize(numpy.sin))
+
+
+def test_can_serialize_numpy_array():
+    x = numpy.ones(10)
+    assert (x == SerializationContext().deserialize(SerializationContext().serialize(x))).all()
diff --git a/typed_python/compiler/tests/type_of_instances_compilation_test.py b/typed_python/compiler/tests/type_of_instances_compilation_test.py
index 337bdc3f..c3fdf459 100644
--- a/typed_python/compiler/tests/type_of_instances_compilation_test.py
+++ b/typed_python/compiler/tests/type_of_instances_compilation_test.py
@@ -17,13 +17,13 @@ def typeOfArg(x: C):

 def test_type_of_alternative_is_specific():
     for members in [{}, {'a': int}]:
-        A = Alternative("A", A=members)
+        Alt = Alternative("Alt", A=members)

         @Entrypoint
-        def typeOfArg(x: A):
+        def typeOfArg(x: Alt):
             return type(x)

-        assert typeOfArg(A.A()) is A.A
+        assert typeOfArg(Alt.A()) is Alt.A


 def test_type_of_concrete_alternative_is_specific():
diff --git a/typed_python/types_serialization_test.py b/typed_python/types_serialization_test.py
index c5f33c2a..e34da54b 100644
--- a/typed_python/types_serialization_test.py
+++
b/typed_python/types_serialization_test.py @@ -15,6 +15,8 @@ import sys import os import importlib +from functools import lru_cache + from abc import ABC, abstractmethod, ABCMeta from typed_python.test_util import callFunctionInFreshProcess import typed_python.compiler.python_ast_util as python_ast_util @@ -57,6 +59,13 @@ module_level_testfun = dummy_test_module.testfunction +class GlobalClassWithLruCache: + @staticmethod + @lru_cache(maxsize=None) + def f(x): + return x + + def moduleLevelFunctionUsedByExactlyOneSerializationTest(): return "please don't touch me" @@ -3061,3 +3070,34 @@ def f(self): print(x) # TODO: make this True # assert x[0].f.__closure__[0].cell_contents is x + + def test_serialize_pyobj_with_custom_reduce(self): + class CustomReduceObject: + def __reduce__(self): + return 'CustomReduceObject' + + assert CustomReduceObject == SerializationContext().deserialize(SerializationContext().serialize(CustomReduceObject)) + + def test_serialize_pyobj_in_MRTG_with_custom_reduce(self): + def getX(): + class InnerCustomReduceObject: + def __reduce__(self): + return 'InnerCustomReduceObject' + + def f(self): + return x + + x = (InnerCustomReduceObject, InnerCustomReduceObject) + + return x + + x = callFunctionInFreshProcess(getX, (), showStdout=True) + + assert x == SerializationContext().deserialize(SerializationContext().serialize(x)) + + def test_serialize_class_static_lru_cache(self): + s = SerializationContext() + + assert ( + s.deserialize(s.serialize(GlobalClassWithLruCache.f)) is GlobalClassWithLruCache.f + )
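The custom-__reduce__ tests above exercise pickle's global-name protocol: when __reduce__ returns a string, the serializer records a reference to a module-level name rather than the object's state, so deserialization hands back the very same object. A minimal stdlib-only illustration of the behavior being mimicked (the Sentinel class and SENTINEL global are hypothetical, not part of the test suite, and must live at module scope for the name lookup to succeed):

    import pickle

    class Sentinel:
        def __reduce__(self):
            # a string return value makes pickle store a global reference,
            # looked up as a module attribute on load
            return 'SENTINEL'

    SENTINEL = Sentinel()

    # identity is preserved, not just equality
    assert pickle.loads(pickle.dumps(SENTINEL)) is SENTINEL

test_serialize_class_static_lru_cache relies on the same idea: CPython's lru_cache wrappers pickle by qualified name, so the round trip must resolve back to the one GlobalClassWithLruCache.f wrapper, which is why `is` rather than `==` is the right check there.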