Thursday, September 4, 2014

ASP.NET MVC Unit Testing Part 2: Faking the Database

Now that the production code is all set up, it's time to create the fakes that I'll inject for my unit tests. Rather than connecting to the production database, I want my unit tests to use an in-memory dataset so they will run lightning fast. Using in-memory data has limitations (in effect, it substitutes LINQ to Objects for LINQ to Entities), which I will address in another post.
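As a small illustration of that substitution: once the data lives in a List<T> exposed through AsQueryable(), a query's Where clause runs as ordinary .NET code against the list. One visible consequence is that string comparisons become case-sensitive, whereas SQL Server under its common case-insensitive default collation would typically match both rows below. The Branch class, InMemoryQueryDemo and the sample values are made up for this sketch, and the collation behaviour is an assumption to check against your own database.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entity used only for this illustration.
public class Branch
{
    public int Branch_Id { get; set; }
    public string Descr { get; set; }
}

public static class InMemoryQueryDemo
{
    public static IQueryable<Branch> Branches()
    {
        var data = new List<Branch>
        {
            new Branch { Branch_Id = 1, Descr = "Army" },
            new Branch { Branch_Id = 2, Descr = "ARMY reserve" },
        };
        // AsQueryable() wraps the list in an EnumerableQuery<T>: the expression tree
        // is compiled and executed in memory (LINQ to Objects), not translated to SQL.
        return data.AsQueryable();
    }

    // In memory, string.Contains(string) is an ordinal, case-sensitive comparison.
    public static int CountContaining(string text) =>
        Branches().Count(b => b.Descr.Contains(text));
}
```

Here CountContaining("Army") finds only the first row; against a case-insensitive database collation the same predicate, translated to a LIKE, would usually find both.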




Faking the Database

I knew that before I could write any tests, I would have to create a fake unit of work, which meant I also needed a fake generic repository. But here I hit another snag: the fakes I got from OdeToFood were not going to work very well. I know there was a good reason, but it escapes me now. Probably it was because OdeToFood used a dictionary to hold the data, so I couldn't fake the Find() method, and I know I had issues with IEnumerable and IQueryable...

Fortunately, I stumbled upon Brent McKendrick's blog post on implementing a fake IDbSet, and it proved to be just the ticket. A little plug-and-play and I was rolling.

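using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel.DataAnnotations;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public class FakeDbSet<T> : IDbSet<T> where T : class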
    {
        private readonly HashSet<T> _data;
        private readonly IQueryable _query;
        private int _identity = 1;
        private List<PropertyInfo> _keyProperties;
 
        private void GetKeyProperties()
        {
            _keyProperties = new List<PropertyInfo>();
            PropertyInfo[] properties = typeof(T).GetProperties();
            foreach (PropertyInfo property in properties)
            {
                foreach (Attribute attribute in property.GetCustomAttributes(true))
                {
                    if (attribute is KeyAttribute)
                    {
                        _keyProperties.Add(property);
                    }
                }
            }
 
            if (_keyProperties.Count == 0)
            {
                var idProperty = properties.FirstOrDefault(p => string.Equals(p.Name, "id", StringComparison.InvariantCultureIgnoreCase));
                if (idProperty != null)
                    _keyProperties.Add(idProperty);
            }
 
            if (_keyProperties.Count == 0)
            {
                var idProperty = properties.FirstOrDefault(p => p.Name.Contains("_Id") || p.Name.Contains("_ID") || p.Name.Contains("_id"));
                if (idProperty != null)
                    _keyProperties.Add(idProperty);
            }
        }
 
        private void GenerateId(T entity)
        {
            // If non-composite integer key
            if (_keyProperties.Count == 1 && _keyProperties[0].PropertyType == typeof(Int32))
                _keyProperties[0].SetValue(entity, _identity++, null);
        }
 
        public FakeDbSet(IEnumerable<T> startData = null)
        {
            GetKeyProperties();
            _data = (startData != null ? new HashSet<T>(startData) : new HashSet<T>());
            _query = _data.AsQueryable();
        }
 
        public virtual T Find(params object[] keyValues)
        {
            if (keyValues.Length != _keyProperties.Count)
                throw new ArgumentException("Incorrect number of keys passed to find method");
 
            IQueryable<T> keyQuery = this.AsQueryable<T>();
            for (int i = 0; i < keyValues.Length; i++)
            {
                var x = i; // copy the loop variable so each deferred Where() uses its own index (nested linq)
                keyQuery = keyQuery.Where(entity => _keyProperties[x].GetValue(entity, null).Equals(keyValues[x]));
            }
 
            return keyQuery.SingleOrDefault();
        }
 
        public T Add(T item)
        {
            GenerateId(item);
            _data.Add(item);
            return item;
        }
 
        public T Remove(T item)
        {
            _data.Remove(item);
            return item;
        }
 
        public T Attach(T item)
        {
            _data.Add(item);
            return item;
        }
 
        public void Detach(T item)
        {
            _data.Remove(item);
        }
 
        Type IQueryable.ElementType
        {
            get { return _query.ElementType; }
        }
 
        Expression IQueryable.Expression
        {
            get { return _query.Expression; }
        }
 
        IQueryProvider IQueryable.Provider
        {
            get { return _query.Provider; }
        }
 
        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return _data.GetEnumerator();
        }
 
        IEnumerator<T> IEnumerable<T>.GetEnumerator()
        {
            return _data.GetEnumerator();
        }
 
        public T Create()
        {
            return Activator.CreateInstance<T>();
        }
 
        public ObservableCollection<T> Local
        {
            get
            {
                return new ObservableCollection<T>(_data);
            }
        }
 
        public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, T
        {
            return Activator.CreateInstance<TDerivedEntity>();
        }
    }
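Find() relies on the "var x = i;" copy marked "nested linq". Where() only records the lambda; it runs later, when SingleOrDefault() enumerates the query, and a C# for loop has a single loop variable shared by every iteration. Capturing i directly would leave every filter looking at index keyValues.Length once the loop has finished, which is past the end of both _keyProperties and keyValues. This self-contained sketch (ClosureDemo is a made-up name) shows the same effect with plain delegates:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ClosureDemo
{
    // Each lambda captures the loop variable itself; all of them share one 'i'.
    public static List<int> CaptureLoopVariable()
    {
        var filters = new List<Func<int>>();
        for (int i = 0; i < 3; i++)
        {
            filters.Add(() => i);
        }
        // The lambdas run after the loop, when i has already reached 3.
        return filters.Select(f => f()).ToList();
    }

    // Each lambda captures a fresh per-iteration copy, as Find() does with 'x'.
    public static List<int> CaptureCopy()
    {
        var filters = new List<Func<int>>();
        for (int i = 0; i < 3; i++)
        {
            var x = i;
            filters.Add(() => x);
        }
        return filters.Select(f => f()).ToList();
    }
}
```

CaptureLoopVariable() yields 3, 3, 3 while CaptureCopy() yields 0, 1, 2; only the second matches what Find() needs.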


The one modification I ended up making was to the GetKeyProperties() method. It turned out that the [Key] attribute wasn't applied to the primary keys in our entity classes, so Find() was still breaking. The comments suggested simply falling back to a property named "Id", which I kept, and I added one more fallback that picks the first property whose name contains "_Id", "_ID" or "_id", since that was the pattern our database followed.
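To make the fallback order explicit, here is a stand-alone sketch of the same lookup. KeyConvention and the three sample classes are hypothetical, not from the project. A [Key] attribute wins, then a property named "Id" in any case, then the first property whose name contains "_Id". Because that last step takes the first match in whatever order Type.GetProperties() returns (the documentation does not guarantee declaration order), an entity that lists a foreign key such as Branch_id ahead of its own key could end up with the wrong property, so it is worth checking against your own classes.

```csharp
using System;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Reflection;

public static class KeyConvention
{
    // Mirrors the fallback order in FakeDbSet.GetKeyProperties().
    public static PropertyInfo[] FindKeys(Type type)
    {
        PropertyInfo[] properties = type.GetProperties();

        // 1) Every property marked with [Key] (composite keys allowed).
        PropertyInfo[] keyed = properties
            .Where(p => p.IsDefined(typeof(KeyAttribute), true))
            .ToArray();
        if (keyed.Length > 0)
            return keyed;

        // 2) A property named "Id", ignoring case.
        PropertyInfo id = properties.FirstOrDefault(
            p => string.Equals(p.Name, "id", StringComparison.InvariantCultureIgnoreCase));
        if (id != null)
            return new[] { id };

        // 3) The first property whose name contains "_Id", "_ID" or "_id".
        PropertyInfo suffixed = properties.FirstOrDefault(
            p => p.Name.Contains("_Id") || p.Name.Contains("_ID") || p.Name.Contains("_id"));
        return suffixed != null ? new[] { suffixed } : Array.Empty<PropertyInfo>();
    }
}

// Hypothetical entities for illustration only.
public class TaggedEntity
{
    public int Id { get; set; }
    [Key] public int Code { get; set; }
}

public class PlainIdEntity
{
    public int ID { get; set; }
    public string Descr { get; set; }
}

public class SchoolLike
{
    public int School_ID { get; set; }
    public string Descr { get; set; }
}
```

TaggedEntity resolves to Code (the attribute beats the "Id" name), PlainIdEntity to ID, and SchoolLike to School_ID.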

The generic repository ended up being a hybrid of the OdeToFood fake: it uses the fake IDbSet to handle queries (GetById() and GetAll()) but follows the OdeToFood pattern for the other CRUD operations, which simply record the entity in the Added, Updated or Removed list. This probably wasn't the best way to handle it, but I knew how the OdeToFood version of the fake worked and how it was used in tests, so I stuck with what I knew. Someday I will probably look back and cringe, but here it is:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

public class FakeGenericRepository<TEntity> : IGenericRepository<TEntity> where TEntity : class
{
 
    public TEntity GetById(object id)
    {
        return dbSet.Find(id);
    }
 
    public IQueryable<TEntity> GetAll()
    {
        return dbSet;
    }
 
    public void Insert(TEntity entity)
    {
        Added.Add(entity);
    }
 
    public void Update(TEntity entity)
    {
        Updated.Add(entity);
    }
 
    public void Delete(TEntity entity)
    {
        Removed.Add(entity);
    }
 
    public void Delete(object id)
    {
        var entity = dbSet.Find(id);
        Removed.Add(entity);
    }
 
    public IEnumerator<TEntity> GetEnumerator()
    {
        return dbSet.AsEnumerable().GetEnumerator();
    }
 
    IEnumerator IEnumerable.GetEnumerator()
    {
        return dbSet.AsEnumerable().GetEnumerator();
    }
 
    public Type ElementType
    {
        get { return dbSet.AsQueryable().ElementType; }
    }
 
    public System.Linq.Expressions.Expression Expression
    {
        get { return dbSet.AsQueryable().Expression; }
    }
 
    public IQueryProvider Provider
    {
        get { return dbSet.AsQueryable().Provider; }
    }
 
    public void AddSet(IQueryable<TEntity> objects)
    {
        foreach (var item in objects)
        {
            dbSet.Add(item);
        }
    }
 
    public FakeDbSet<TEntity> dbSet = new FakeDbSet<TEntity>();
    public List<object> Added = new List<object>();
    public List<object> Updated = new List<object>();
    public List<object> Removed = new List<object>();
    public bool Saved = false;
}


It's not pretty, but it got the job done. Finally, the FakeUnitOfWork was practically a line-for-line clone of the SQL version, just using the fakes instead of the live database. Its Commit() implementation (the fake's equivalent of SaveChanges()) just flips the "Saved" flag on every repository that has been created.

class FakeUnitOfWork : IUnitOfWork
    {
        private FakeGenericRepository<Application> applications;
        private FakeGenericRepository<AspNetRole> aspNetRoles;
        private FakeGenericRepository<AspNetUserClaim> aspNetUserClaims;
        private FakeGenericRepository<AspNetUserLogin> aspNetUserLogins;
        private FakeGenericRepository<AspNetUser> aspNetUsers;
        private FakeGenericRepository<C__MigrationHistory> c__MigrationHistory;
        private FakeGenericRepository<LuBranch> luBranches;
        private FakeGenericRepository<LuDegree> luDegrees;
        private FakeGenericRepository<LuEligibility> luEligibilities;
        private FakeGenericRepository<LuGrade> luGrades;
        private FakeGenericRepository<LuSchool> luSchools;
        private FakeGenericRepository<LuState> luStates;
        private FakeGenericRepository<LuUnit> luUnits;
        private FakeGenericRepository<Term> terms;
        private FakeGenericRepository<UnitPersonnel> unitPersonnels;
 
        private BranchUnitRepository branchRepository;
        private ApplicationTermsRepository termRepository;
 
        public IGenericRepository<Application> Applications
        {
            get
            {
                if (this.applications == null)
                {
                    this.applications = new FakeGenericRepository<Application>();
                }
                return applications;
            }
        }
 
        public IGenericRepository<AspNetRole> AspNetRoles
        {
            get
            {
 
                if (this.aspNetRoles == null)
                {
                    this.aspNetRoles = new FakeGenericRepository<AspNetRole>();
                }
                return aspNetRoles;
            }
        }

        // ... (the remaining repository properties follow the same pattern)
 
        //other repos
        public IBranchUnitRepository _branchRepository
        {
            get
            {
                if (this.branchRepository == null)
                {
                    this.branchRepository = new BranchUnitRepository(this);
                }
                return branchRepository;
            }
        }
 
        public IApplicationTermRepository _termRepository
        {
            get
            {
                if (this.termRepository == null)
                {
                    this.termRepository = new ApplicationTermsRepository(this);
                }
                return termRepository;
            }
        }
 
        public void Commit()
        {
            if (applications != null) applications.Saved = true;
            if (aspNetRoles != null) aspNetRoles.Saved = true;
            if (aspNetUserClaims != null) aspNetUserClaims.Saved = true;
            if (aspNetUserLogins != null) aspNetUserLogins.Saved = true;
            if (aspNetUsers != null) aspNetUsers.Saved = true;
            if (c__MigrationHistory != null) c__MigrationHistory.Saved = true;
            if (luBranches != null) luBranches.Saved = true;
            if (luDegrees != null) luDegrees.Saved = true;
            if (luEligibilities != null) luEligibilities.Saved = true;
            if (luGrades != null) luGrades.Saved = true;
            if (luSchools != null) luSchools.Saved = true;
            if (luStates != null) luStates.Saved = true;
            if (luUnits != null) luUnits.Saved = true;
            if (terms != null) terms.Saved = true;
            if (unitPersonnels != null) unitPersonnels.Saved = true;
        }
    }
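Put together, a test drives the code under test through the unit of work and then inspects the recorded lists and the Saved flag instead of re-querying: an inserted entity is only recorded in Added and never reaches dbSet, so GetAll() won't return it. The following self-contained sketch shows that flow with cut-down stand-ins; RecordingRepository, TinyUnitOfWork, Item and ItemService are hypothetical names, and only one lazily created repository is wired up.

```csharp
using System.Collections.Generic;
using System.Linq;

public class Item
{
    public string Descr { get; set; }
}

// Cut-down stand-in for FakeGenericRepository<TEntity> (hypothetical).
public class RecordingRepository<T> where T : class
{
    private readonly List<T> _data = new List<T>();
    public List<object> Added = new List<object>();
    public bool Saved = false;

    public void AddSet(IEnumerable<T> items) => _data.AddRange(items);

    // Queries only see seeded data, like GetAll() over dbSet.
    public IQueryable<T> GetAll() => _data.AsQueryable();

    // Inserts are recorded for assertions, not added to the queryable data.
    public void Insert(T entity) => Added.Add(entity);
}

// Cut-down stand-in for FakeUnitOfWork with one lazily created repository.
public class TinyUnitOfWork
{
    private RecordingRepository<Item> items;

    public RecordingRepository<Item> Items => items ?? (items = new RecordingRepository<Item>());

    public bool ItemsCreated => items != null;

    // Like Commit() above: only repositories that exist get flagged.
    public void Commit()
    {
        if (items != null) items.Saved = true;
    }
}

// Hypothetical code under test.
public class ItemService
{
    private readonly TinyUnitOfWork _db;
    public ItemService(TinyUnitOfWork db) { _db = db; }

    public void Create(string descr)
    {
        _db.Items.Insert(new Item { Descr = descr });
        _db.Commit();
    }
}
```

In a test you would seed through AddSet(), call ItemService.Create(), and then assert that Added holds one item and Saved is true, while GetAll() still returns only the seeded row.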


The only thing left to do was to stuff some fake data in there. This is where the AddSet() helper on the fake repository comes in very handy, though I have a feeling that all the casts I have to do are a sign I'm not doing this right (that goes for the tests too...). It got the job done, and I'm sure it's only a matter of time before I figure out how to fix it. C'est la vie...
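If the casts are the main irritation, one option is to confine them to a single generic helper, so a repository of the wrong type fails with a clear message instead of an InvalidCastException. This is only a sketch: the Seed extension method is hypothetical, and IGenericRepository and FakeGenericRepository below are trimmed-down stand-ins so the example compiles on its own.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Trimmed-down stand-ins for the real interface and fake (hypothetical).
public interface IGenericRepository<TEntity> where TEntity : class
{
    IQueryable<TEntity> GetAll();
}

public class FakeGenericRepository<TEntity> : IGenericRepository<TEntity> where TEntity : class
{
    private readonly List<TEntity> _data = new List<TEntity>();
    public IQueryable<TEntity> GetAll() => _data.AsQueryable();
    public void AddSet(IQueryable<TEntity> objects) => _data.AddRange(objects);
}

public static class RepositorySeeding
{
    // The one place that knows the interface is backed by a fake.
    public static void Seed<TEntity>(this IGenericRepository<TEntity> repo, IQueryable<TEntity> data)
        where TEntity : class
    {
        if (!(repo is FakeGenericRepository<TEntity> fake))
            throw new InvalidOperationException(
                $"Expected a FakeGenericRepository<{typeof(TEntity).Name}>, got {repo.GetType().Name}.");
        fake.AddSet(data);
    }
}
```

With a helper like this, each line in addTestData() could read db.LuBranches.Seed(TestData.LuBranches); with the type inferred from the property.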

So here is the fake data:

class TestData
    {
 
        public static FakeUnitOfWork addTestData(FakeUnitOfWork db)
        {
            ((FakeGenericRepository<LuBranch>)db.LuBranches).AddSet(TestData.LuBranches);
            ((FakeGenericRepository<LuEligibility>)db.LuEligibilities).AddSet(TestData.LuEligibilitys);
            ((FakeGenericRepository<LuGrade>)db.LuGrades).AddSet(TestData.LuGrades);
            ((FakeGenericRepository<LuSchool>)db.LuSchools).AddSet(TestData.LuSchools);
            ((FakeGenericRepository<LuState>)db.LuStates).AddSet(TestData.LuStates);
            ((FakeGenericRepository<LuUnit>)db.LuUnits).AddSet(TestData.LuUnits);
            ((FakeGenericRepository<AspNetUser>)db.AspNetUsers).AddSet(TestData.AspNetUsers);
            ((FakeGenericRepository<Application>)db.Applications).AddSet(TestData.Applications);
            ((FakeGenericRepository<UnitPersonnel>)db.UnitPersonnels).AddSet(TestData.UnitPersonnels);
            ((FakeGenericRepository<Term>)db.Terms).AddSet(TestData.Terms);
            ((FakeGenericRepository<LuDegree>)db.LuDegrees).AddSet(TestData.LuDegrees);
            return db;
        }
 
        public static IQueryable<LuBranch> LuBranches
        {
            get
            {
                var branches = new List<LuBranch>();
                for (var i = 0; i < 10; i++)
                {
                    LuBranch branch = new LuBranch();
                    //branch.Branch_Id = i;
                    branch.Descr = "test branch " + i.ToString();
                    branches.Add(branch);
                }
                return branches.AsQueryable();
            }
        }
 
        public static IQueryable<LuEligibility> LuEligibilitys
        {
            get
            {
                var eligs = new List<LuEligibility>();
                for (var i = 0; i < 10; i++)
                {
                    LuEligibility elig = new LuEligibility();
                    //elig.Eligibility_Id = i;
                    elig.Descr = "test eligibility " + i.ToString();
                    eligs.Add(elig);
                }
                return eligs.AsQueryable();
            }
        }
 
        public static IQueryable<LuGrade> LuGrades
        {
            get
            {
                var grades = new List<LuGrade>();
                for (var i = 0; i < 10; i++)
                {
                    LuGrade grade = new LuGrade();
                    //grade.Grade_Id = i;
                    grade.Descr = "test grade " + i.ToString();
                    grades.Add(grade);
                }
                return grades.AsQueryable();
            }
        }
 
        public static IQueryable<LuSchool> LuSchools
        {
            get
            {
                var schools = new List<LuSchool>();
                for (var i = 0; i < 10; i++)
                {
                    LuSchool school = new LuSchool();
                    //school.School_ID = i;
                    school.Descr = "test school " + i.ToString();
                    school.Is_AU_ABC = true;
                    schools.Add(school);
                }
                return schools.AsQueryable();
            }
        }
 
        public static IQueryable<LuDegree> LuDegrees
        {
            get
            {
                var degrees = new List<LuDegree>();
                for (var i = 0; i < 10; i++)
                {
                    LuDegree degree = new LuDegree();
                    //degree.Degree_Id = i;
                    degree.Descr = "test degree " + i.ToString();
                    degrees.Add(degree);
                }
                return degrees.AsQueryable();
            }
        }
 
        public static IQueryable<LuState> LuStates
        {
            get
            {
                var states = new List<LuState>();
                for (var i = 0; i < 10; i++)
                {
                    LuState state = new LuState();
                    //state.State_Id = i;
                    state.Descr = "test state " + i.ToString();
                    states.Add(state);
                }
                return states.AsQueryable();
            }
        }
 
        public static IQueryable<LuUnit> LuUnits
        {
            get
            {
                var units = new List<LuUnit>();
                for (var i = 0; i < 10; i++)
                {
                    LuUnit unit = new LuUnit();
                    //unit.Unit_Id = i;
                    unit.Branch_id = (int)Math.Round((double)i / 2);
                    unit.Descr = "test unit " + i.ToString();
                    units.Add(unit);
                }
                return units.AsQueryable();
            }
        }
 
        public static IQueryable<AspNetUser> AspNetUsers
        {
            get
            {
                var users = new List<AspNetUser>();
                for (var i = 0; i < 10; i++)
                {
                    AspNetUser user = new AspNetUser();
                    user.Id = i.ToString();
                    user.FirstName = "test";
                    user.LastName = "user" + i.ToString();
                    user.Email = "testuser" + i.ToString() + "@blah.net";
                    users.Add(user);
                }
                return users.AsQueryable();
            }
        }
 
        public static IQueryable<UnitPersonnel> UnitPersonnels
        {
            get
            {
                var ups = new List<UnitPersonnel>();
                for (var i = 0; i < 10; i++)
                {
                    UnitPersonnel up = new UnitPersonnel();
                    //up.UnitPersonnel_Id = i;
                    up.Unit_Id = (int)Math.Round((double)i / 2) + 1;
                    up.AspNetUserID = (Math.Round((double)i / 2)).ToString();
                    ups.Add(up);
                }
                return ups.AsQueryable();
            }
        }
 
        public static IQueryable<Term> Terms
        {
            get
            {
                var terms = new List<Term>();
                for (var i = 0; i < 10; i++)
                {
                    Term t = new Term();
                    switch(i%3)
                    {
                        case 0:
                        t.Descr = "Fall 201" + Math.Ceiling((double)(i / 3)).ToString();
                            break;
                        case 1:
                        t.Descr = "Spring 201" + Math.Ceiling((double)(i / 3)).ToString();
                            break;
                        case 2:
                        t.Descr = "Summer 201" + Math.Ceiling((double)(i / 3)).ToString();
                            break;
                    }
                    if(i<5)
                    { t.Is_Visible = true; }
                    else
                    { t.Is_Visible = false; }
                    
                    terms.Add(t);
                }
                return terms.AsQueryable();
            }
        }
 
        public static IQueryable<Application> Applications
        {
            get
            {
                var apps = new List<Application>();
 
                for (var i = 0; i < 10; i++)
                {
                    bool rotc = false;
                    if (i % 3 == 0)
                        rotc = true;
 
                    apps.Add(new Application()
                    {
                        Application_Id = i,
                        AspNetUsersID = i.ToString(),
                        Branch_Id = (int)Math.Round((double)i / 2) + 1,
                        Degree_Id = (int)Math.Round((double)i / 2) + 1,
                        City = "Cheyenne",
                        Day_Phone = "555-555" + i.ToString(),
                        Email = "testuser" + i.ToString() + "@blah.net",
                        Military_email = "testuser" + i.ToString() + "@blah.mil",
                        Grade_Id = i + 1,
                        Eligibility_Id = i + 1,
                        First_Name = "Test",
                        Last_Name = "User" + i.ToString(),
                        Home_Phone = "555-123" + i.ToString(),
                        Is_ROTC = rotc,
                        Last_Four_SSN = "999" + i.ToString(),
                        Modified_Date = DateTime.Now,
                        School_Id = (int)Math.Round((double)i / 2) + 1,
                        Term_Id = (int)Math.Round((double)i / 2) + 1,
                        State_Id = 1,
                        Unit_Id = 1,
                        Street = "123" + i.ToString() + " Test St.",
                        Zip = "90210-5555"
                    });
 
                    apps.Add(new Application()
                    {
                        Application_Id = i + 10,
                        AspNetUsersID = i.ToString(),
                        Branch_Id = i + 1,
                        Degree_Id = i + 1,
                        City = "Cheyenne",
                        Day_Phone = "555-555" + i.ToString(),
                        Email = "testuser" + i.ToString() + "@blah.net",
                        Military_email = "testuser" + i.ToString() + "@blah.mil",
                        Grade_Id = 1,
                        Eligibility_Id = i + 1,
                        First_Name = "Test",
                        Last_Name = "User" + i.ToString(),
                        Home_Phone = "555-123" + i.ToString(),
                        Is_ROTC = rotc,
                        Last_Four_SSN = "999" + i.ToString(),
                        Modified_Date = DateTime.Today.AddDays(-60),
                        School_Id = (int)Math.Round((double)i / 2) + 1,
                        Term_Id = (int)Math.Round((double)i / 2) + 1,
                        State_Id = i + 1,
                        Unit_Id = (int)Math.Round((double)i / 2) + 1,
                        Street = "123" + i.ToString() + " Test St.",
                        Zip = "90210-5555",
                        Verified_By = ((int)Math.Round((double)i / 2) + 1).ToString(),
                        Verified_Date = DateTime.Now
                    });
                }
 
                return apps.AsQueryable();
            }
        }
    }
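A note on the arithmetic in this test data, since it decides which rows relate to which. Math.Round() defaults to banker's rounding (MidpointRounding.ToEven), so (int)Math.Round((double)i / 2) for i = 0 to 9 yields 0, 0, 1, 2, 2, 2, 3, 4, 4, 4 rather than stepping up on every second row. In the Terms getter, the integer division i / 3 runs before the cast to double, so Math.Ceiling() never changes anything. The hypothetical TestDataMath helper below just recomputes those values (plus an AwayFromZero variant, in case round-half-up was what was intended) so they can be checked when writing test expectations.

```csharp
using System;
using System.Collections.Generic;

public static class TestDataMath
{
    // The expression used for Branch_id, Unit_Id, School_Id, etc. (before any "+ 1").
    public static List<int> HalfIndexes()
    {
        var values = new List<int>();
        for (var i = 0; i < 10; i++)
            values.Add((int)Math.Round((double)i / 2)); // midpoints round to the even neighbour
        return values;
    }

    // Same expression, but rounding midpoints away from zero.
    public static List<int> HalfIndexesAwayFromZero()
    {
        var values = new List<int>();
        for (var i = 0; i < 10; i++)
            values.Add((int)Math.Round((double)i / 2, MidpointRounding.AwayFromZero));
        return values;
    }

    // The Terms descriptions: i / 3 is integer division, so Ceiling() is a no-op.
    public static List<string> TermDescriptions()
    {
        string[] seasons = { "Fall", "Spring", "Summer" };
        var values = new List<string>();
        for (var i = 0; i < 10; i++)
            values.Add(seasons[i % 3] + " 201" + Math.Ceiling((double)(i / 3)).ToString());
        return values;
    }
}
```

The AwayFromZero variant gives 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, and the term descriptions run Fall 2010, Spring 2010, Summer 2010, Fall 2011 and so on up to Fall 2013.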

So I'm all set, right? Well, I thought so, until the first time a call to User.Identity.GetUserId() threw a NullReferenceException... DOH! Next I'll need to look into faking parts of the environment, specifically the HttpContext.

