A Better Way to Mock Context Managers in Python

I came across an interesting way to set up mock objects in Python the other day. I was trying to mock out a call to the database. I always have trouble mocking context managers, because I so often forget that it is the return_value of __enter__ that needs to be configured. Anyway, the code I wanted to test looked something like this:

from pymssql import connect

def load_data(server, database):
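    # Keyword arguments: pymssql.connect's second positional parameter is user, not database.
    with connect(server=server, database=database, user='', password='') as conn: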
        with conn.cursor() as cursor:
            cursor.callproc('pLoadData')
            result = cursor.fetchone()
            # do some stuff with result
            return result

After some head-scratching and several attempts, I managed to set up the mock so that fetchone returns my own result.

from unittest.mock import patch

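# Patch connect where load_data looks it up (the notebook's __main__ module),
# not in pymssql; in a test suite this would be e.g. 'mymodule.connect' (hypothetical name).
with patch('__main__.connect') as mock_connect:

    mock_connect.return_value.__enter__.return_value \
        .cursor.return_value.__enter__.return_value \
            .fetchone.return_value = 42

    print(load_data('server', 'database'))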
42
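The chain reflects how the with statement works: with obj as x binds x to the return value of obj.__enter__(), not to obj itself. Here is a minimal sketch with a bare MagicMock. The read attribute is only an arbitrary illustrative name, not part of any real API.

```python
from unittest.mock import MagicMock

# MagicMock supports the context manager protocol out of the box.
cm = MagicMock()

# `with cm as value:` binds the *return value* of cm.__enter__(),
# so that is the object that has to be configured.
cm.__enter__.return_value.read.return_value = "hello"

with cm as value:
    result = value.read()

print(result)  # → hello
print(value is cm.__enter__.return_value)  # → True
```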

However, this is cumbersome, error-prone and easy to forget. Fortunately, Mock and MagicMock create a child mock the first time an attribute is accessed. Calling a mock also returns the same return_value object every time, whatever the arguments. So we can simply call the same functions inside with blocks, just as load_data does, until we reach the cursor. The dummy value still has to be assigned to the return_value of fetchone, but this time without chaining all those __enter__ and return_value attributes. When the function under test runs, the connection and cursor mocks created during setup already exist, and load_data is handed exactly those objects. I reckon this is much more readable.

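# Patch connect where load_data looks it up (the notebook's __main__ module).
with patch('__main__.connect') as mock_connect:

    # Walk through the same with blocks as load_data; each step returns the
    # same child mock that load_data will receive when it runs.
    with mock_connect() as conn:
        with conn.cursor() as cursor:
            cursor.fetchone.return_value = 42

    print(load_data('server', 'database'))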
42
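The with-block version relies on one property of mocks. Calling a mock returns its return_value, and that is the same object on every call, whatever arguments are passed. Entering the mock during setup therefore hands out exactly the objects that load_data receives later. Here is a small sketch with a bare MagicMock, so no pymssql is needed.

```python
from unittest.mock import MagicMock

connect = MagicMock()

# Every call returns the same return_value, regardless of the arguments.
print(connect() is connect("server", database="db"))  # → True
print(connect() is connect.return_value)  # → True

# __enter__ behaves the same way, so separate with blocks see one object.
with connect() as first:
    pass
with connect("server", database="db") as second:
    pass

print(first is second)  # → True
print(first is connect.return_value.__enter__.return_value)  # → True
```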

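One side effect is worth knowing about. The setup calls are recorded like any others, so an assertion such as assert_called_once_with on connect would also count the call made while configuring the mock. The sketch below shows one way to handle that, assuming it is acceptable to call reset_mock() between setup and exercise. By default, reset_mock() clears recorded calls but keeps configured return values. To stay runnable without pymssql or patching, this version of load_data takes the connection factory as a parameter; that signature is hypothetical.

```python
from unittest.mock import MagicMock

# Hypothetical variant of load_data that receives the connection factory
# as a parameter, so this sketch runs without pymssql or patch().
def load_data(connect, server, database):
    with connect(server=server, database=database, user="", password="") as conn:
        with conn.cursor() as cursor:
            cursor.callproc("pLoadData")
            return cursor.fetchone()

connect = MagicMock()

# Configure the mock by walking through the same with blocks as the code.
with connect() as conn:
    with conn.cursor() as cursor:
        cursor.fetchone.return_value = 42

# The setup itself was recorded as a call to connect().
print(connect.call_count)  # → 1

# Clear recorded calls; configured return values are kept by default.
connect.reset_mock()

result = load_data(connect, "server", "database")
print(result)  # → 42

# Assertions now only see the calls made by the code under test.
connect.assert_called_once_with(
    server="server", database="database", user="", password=""
)
cursor.callproc.assert_called_once_with("pLoadData")
```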