A better Python cache for slow function calls

William Zeng - March 18th, 2024


We wrote a file cache - it's like Python's lru_cache, but it stores values in files on disk instead of in memory. This has saved us hours of time running our LLM benchmarks, and we'd like to share it as its own Python module. Thanks to Luke Jaggernauth (former Sweep engineer) for building the initial version of this!

Here's the link: https://github.com/sweepai/sweep/blob/main/docs/public/file_cache.py. To use it, simply add the file_cache decorator to your function. Here's an example:

import time
from file_cache import file_cache
 
@file_cache()
def slow_function(x, y):
    time.sleep(30)
    return x + y
 
print(slow_function(1, 2)) # -> 3, takes 30 seconds
print(slow_function(1, 2)) # -> 3, takes 0 seconds

Background

We spend a lot of time prompt engineering our agents at Sweep. Our agents take a set of input strings, format them into a prompt, then send the prompt and any other information off to the LLM. We chain multiple agents to turn a GitHub issue into a pull request. For example, to modify code we'll input the old code, any relevant context, and instructions, then output the new code.

[Diagram: multiple LLM steps chained together]

A typical improvement involves tweaking a small part of our pipeline (like improving our planning algorithm), then running the entire pipeline again. We use pdb (Python's native debugger) to set breakpoints and inspect the state of our prompts, input values, and parsing logic. For example, we can check whether a certain string matches a regex:

(Pdb) p todays_date
'2024-03-14'
(Pdb) re.match(r"^\d{4}-\d{2}-\d{2}$", todays_date)
<re.Match object; span=(0, 10), match='2024-03-14'>

This lets us debug at runtime with the actual data.

Cached pdb

pdb works great, but we have to wait for the entire pipeline to run again each time. Imagine a version of pdb that not only interrupts execution but also caches the entire program state up to that point. Our time to hit the same bug could be cut from 10 minutes to 15 seconds (a 40x improvement).

We didn't build this, but we think our file_cache works just as well.

LLM calls are slow, but their inputs and outputs are easy to cache: we can use the input prompt as the cache key and the output string as the cache value, which saves a lot of time across repeated runs.
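For example, an LLM wrapper can be cached directly. Here's a minimal sketch (call_llm and its sleep are stand-ins, not our actual client code):

import time
from file_cache import file_cache

@file_cache()
def call_llm(prompt: str) -> str:
    # Stand-in for a slow LLM request: the prompt string becomes part of the
    # cache key, and the returned string is pickled to disk as the cache value.
    time.sleep(10)
    return f"response to: {prompt}"

print(call_llm("Write a plan for this GitHub issue"))  # slow the first time
print(call_llm("Write a plan for this GitHub issue"))  # loaded from /tmp/file_cache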

What's different from lru_cache?

lru_cache is great for memoizing repeated function calls, but it's missing two features we need:

  1. We need to persist the cache between runs.

    • lru_cache stores results in memory, so the cache is empty the next time you run the program. file_cache stores results on disk. We also considered Redis, but writing to disk is easier to set up and manage.
  2. lru_cache can't ignore arguments that would spuriously invalidate the cache.

    • We use a custom chat_logger, which stores our chats for visualization. It holds the current timestamp chat_logger.expiration, which would invalidate the cache on every run if it were hashed into the key.
    • To counteract this, we added ignored parameters, used like this: file_cache(ignore_params=["chat_logger"]). This removes chat_logger from the cache key construction and prevents spurious invalidation from the constantly changing expiration (see the sketch after this list).
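Putting the two together, here's a rough sketch of how this looks in practice (the agent function and the ChatLogger stub are illustrative; only the ignore_params usage comes from the real module):

import time
from file_cache import file_cache

class ChatLogger:
    # Illustrative stand-in: the real chat_logger carries a constantly
    # changing expiration timestamp.
    def __init__(self):
        self.expiration = time.time()

@file_cache(ignore_params=["chat_logger"])
def plan_changes(issue_title, issue_body, chat_logger):
    # ... slow LLM call ...
    return f"plan for: {issue_title}"

# The second call is served from the cache entry written by the first, even
# though each ChatLogger has a fresh expiration, because chat_logger is
# dropped before the key is hashed.
print(plan_changes("Fix login bug", "Steps to reproduce ...", ChatLogger()))
print(plan_changes("Fix login bug", "Steps to reproduce ...", ChatLogger()))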

Implementation

Our two main methods are recursive_hash and file_cache.

recursive_hash

We want to stably hash objects, and this is not natively supported in Python.

import hashlib
from file_cache import recursive_hash
 
 
class Obj:
    def __init__(self, name):
        self.name = name
 
obj = Obj("test")
print(recursive_hash(obj)) # -> this works fine
try:
    hashlib.md5(obj).hexdigest()
except Exception as e:
    print(e) # -> this doesn't work

hashlib.md5 alone doesn't work on objects, giving us the error TypeError: object supporting the buffer API required. We use recursive_hash, which works for arbitrary Python objects.

import hashlib

MAX_DEPTH = 6  # module-level recursion limit in file_cache.py (exact value may differ)

def recursive_hash(value, depth=0, ignore_params=[]):
    """Hash primitives recursively with maximum depth."""
    if depth > MAX_DEPTH:
        return hashlib.md5("max_depth_reached".encode()).hexdigest()
 
    if isinstance(value, (int, float, str, bool, bytes)):
        return hashlib.md5(str(value).encode()).hexdigest()
    elif isinstance(value, (list, tuple)):
        return hashlib.md5(
            "".join(
                [recursive_hash(item, depth + 1, ignore_params) for item in value]
            ).encode()
        ).hexdigest()
    elif isinstance(value, dict):
        return hashlib.md5(
            "".join(
                [
                    recursive_hash(key, depth + 1, ignore_params)
                    + recursive_hash(val, depth + 1, ignore_params)
                    for key, val in value.items()
                    if key not in ignore_params
                ]
            ).encode()
        ).hexdigest()
    elif hasattr(value, "__dict__") and value.__class__.__name__ not in ignore_params:
        return recursive_hash(value.__dict__, depth + 1, ignore_params)
    else:
        return hashlib.md5("unknown".encode()).hexdigest()
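As a quick sanity check (assuming recursive_hash is importable from the same file_cache module), two dicts that differ only in an ignored key hash to the same digest:

from file_cache import recursive_hash

with_logger = {"prompt": "modify foo.py", "chat_logger": "expires 2024-03-18"}
without_logger = {"prompt": "modify foo.py"}

# The ignored key is skipped while hashing the dict, so both collapse to the same digest.
assert recursive_hash(with_logger, ignore_params=["chat_logger"]) == recursive_hash(
    without_logger, ignore_params=["chat_logger"]
)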

file_cache

file_cache is a decorator that handles the caching logic for us.

@file_cache()
def search_codebase(
    cloned_github_repo,
    query,
):
    # ... take a long time ...
    # ... llm agent logic to search through the codebase ...
    return top_results

Without the cache, searching through the codebase with our LLM agent to get top_results takes 5 minutes - far too long when the search itself isn't what we're testing. With file_cache, we only wait for deserialization of the pickled object - essentially instantaneous for search results.

Wrapper

First we store our cache in /tmp/file_cache. This lets us remove the cache by simply deleting the directory (running rm -rf /tmp/file_cache). We can also selectively remove function calls using rm -rf /tmp/file_cache/search_codebase*.

def wrapper(*args, **kwargs):
    cache_dir = "/tmp/file_cache"
    os.makedirs(cache_dir, exist_ok=True)
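For orientation, here's a simplified sketch of how this wrapper sits inside the file_cache decorator factory (structure only; the key construction and cache read/write from the next sections go where the comments are, and the exact signature may differ from the real module):

import functools
import os

def file_cache(verbose=False, ignore_params=[]):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            cache_dir = "/tmp/file_cache"
            os.makedirs(cache_dir, exist_ok=True)
            # 1. Build cache_file from the arguments and function source (next section).
            # 2. On a hit, unpickle and return; on a miss, call func, pickle the
            #    result, and return it (final section).
            ...
        return wrapper
    return decorator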

Then we can create a cache key.

Cache Key Creation / Miss Conditions

We have another problem - we want to miss our cache under two conditions:

  1. The arguments to the function change - handled by recursive_hash
  2. The code changes

To handle 2, we use inspect.getsource(func) to add the function's source code to the hash, correctly missing the cache when the code changes.

    func_source_code_hash = hash_code(inspect.getsource(func))
    args_names = func.__code__.co_varnames[: func.__code__.co_argcount]
    args_dict = dict(zip(args_names, args))
 
    # Remove ignored params
    kwargs_clone = kwargs.copy()
    for param in ignore_params:
        args_dict.pop(param, None)
        kwargs_clone.pop(param, None)
 
    # Create hash based on argument names, argument values, and function source code
    arg_hash = (
        recursive_hash(args_dict, ignore_params=ignore_params)
        + recursive_hash(kwargs_clone, ignore_params=ignore_params)
        + func_source_code_hash
    )
    cache_file = os.path.join(
        cache_dir, f"{func.__module__}_{func.__name__}_{arg_hash}.pickle"
    )

Cache hits and misses

Finally, we check whether the cache file exists: on a hit we load and return the pickled result, and on a miss we call the function and write its result to the cache.

    try:
        # If cache exists, load and return it
        if os.path.exists(cache_file):
            if verbose:
                print("Used cache for function: " + func.__name__)
            with open(cache_file, "rb") as f:
                return pickle.load(f)
    except Exception:
        logger.info("Unpickling failed")
 
    # Otherwise, call the function and save its result to the cache
    result = func(*args, **kwargs)
    try:
        with open(cache_file, "wb") as f:
            pickle.dump(result, f)
    except Exception as e:
        logger.info(f"Pickling failed: {e}")
    return result

Conclusion

We hope this code is useful to you. We've found it to be a massive time saver when debugging LLM calls. We'd love to hear your feedback and contributions at https://github.com/sweepai/sweep!