252 lines
8.3 KiB
C#
252 lines
8.3 KiB
C#
// <copyright file="BaseRepository.cs" company="alveus.dev">
|
|
// Copyright (c) alveus.dev. All rights reserved. Licensed under the MIT License.
|
|
// </copyright>
|
|
|
|
using System.Data;
|
|
using System.Reflection;
|
|
using System.Text;
|
|
using Astral.Core.Attributes.EntityAnnotation;
|
|
using Astral.Core.Exceptions;
|
|
using Astral.Core.Infrastructure;
|
|
using Astral.Core.RepositoryInterfaces;
|
|
using Dapper;
|
|
|
|
namespace Astral.DAL.Repositories;
|
|
|
|
/// <summary>
/// Generic Dapper-based repository whose SQL statements are generated from the mapping attributes on
/// <typeparamref name="TEntity" />.
/// </summary>
/// <remarks>
/// The table name comes from <see cref="TableMappingAttribute" /> (falling back to the type name with an "s"
/// appended), the key from the single property marked with <see cref="PrimaryKeyAttribute" />, and the columns
/// from properties marked with <see cref="ColumnMappingAttribute" />. All statements are built once, in the
/// constructor. The generated INSERT ends with a <c>RETURNING</c> clause, so the target database must support it.
/// </remarks>
/// <typeparam name="TEntity">Entity type mapped to a table.</typeparam>
/// <typeparam name="TKeyType">Type of the entity's primary key.</typeparam>
public class BaseRepository<TEntity, TKeyType> : IGenericRepository<TEntity, TKeyType>
    where TEntity : class
{
    /// <summary>
    /// Instance of <see cref="IDbConnectionProvider" /> used to open a connection for each operation.
    /// </summary>
    private readonly IDbConnectionProvider _db;

    /// <summary>
    /// Initializes a new instance of the <see cref="BaseRepository{TEntity, TKeyType}" /> class and generates
    /// the SQL statements for <typeparamref name="TEntity" />.
    /// </summary>
    /// <param name="db">Instance of <see cref="IDbConnectionProvider" />.</param>
    /// <exception cref="PrimaryKeyMissingException">No property of the entity is marked as primary key.</exception>
    /// <exception cref="MultiplePrimaryKeysException">More than one property is marked as primary key.</exception>
    /// <exception cref="MissingColumnAttributeException">The primary key property has no column mapping.</exception>
    protected BaseRepository(IDbConnectionProvider db)
    {
        _db = db;
        SetupRepository();
    }

    /// <summary>
    /// Table name. Assigned during construction; the generated statements capture the value at that point,
    /// so changing it afterwards does not affect them.
    /// </summary>
    protected string Table { get; set; }

    /// <summary>
    /// Primary key column name.
    /// </summary>
    private string PrimaryKeyColumn { get; set; }

    /// <summary>
    /// Whether the primary key should be treated as auto-generated by the database
    /// (and therefore left out of the INSERT column list).
    /// </summary>
    private bool AutoPrimaryKey { get; set; }

    /// <summary>
    /// Generated select all query string.
    /// </summary>
    private string SelectAllQuery { get; set; }

    /// <summary>
    /// Generated find by id query string; expects a <c>pId</c> parameter.
    /// </summary>
    private string FindByIdQuery { get; set; }

    /// <summary>
    /// Generated insert query string; parameters are bound from the entity's property names.
    /// </summary>
    private string InsertQuery { get; set; }

    /// <summary>
    /// Generated update query string; parameters are bound from the entity's property names.
    /// </summary>
    private string UpdateQuery { get; set; }

    /// <summary>
    /// Generated delete query string; expects a <c>pId</c> parameter.
    /// </summary>
    private string DeleteQuery { get; set; }

    /// <inheritdoc />
    public async Task<TEntity> FindByIdAsync(TKeyType id)
    {
        return await WithDatabaseAsync(async connection =>
            await connection.QueryFirstOrDefaultAsync<TEntity>(FindByIdQuery, new { pId = id }));
    }

    /// <inheritdoc />
    public async Task<IEnumerable<TEntity>> GetAllAsync()
    {
        return await WithDatabaseAsync(async connection =>
            await connection.QueryAsync<TEntity>(SelectAllQuery));
    }

    /// <inheritdoc />
    public async Task<TKeyType> AddAsync(TEntity entity)
    {
        ArgumentNullException.ThrowIfNull(entity);

        // The INSERT ends with RETURNING <key>, so the single returned value is the new row's key.
        return await WithDatabaseAsync(async connection =>
            await connection.QuerySingleAsync<TKeyType>(InsertQuery, entity));
    }

    /// <inheritdoc />
    public async Task AddAsync(IEnumerable<TEntity> entities)
    {
        ArgumentNullException.ThrowIfNull(entities);

        await WithDatabaseAsync(async connection =>
        {
            // Dapper executes the statement once per element of the sequence.
            await connection.ExecuteAsync(InsertQuery, entities);
        });
    }

    /// <inheritdoc />
    public async Task<bool> UpdateAsync(TEntity entity)
    {
        ArgumentNullException.ThrowIfNull(entity);

        return await WithDatabaseAsync(async connection =>
            await connection.ExecuteAsync(UpdateQuery, entity) > 0);
    }

    /// <inheritdoc />
    public async Task DeleteAsync(TKeyType id)
    {
        await WithDatabaseAsync(async connection =>
            await connection.ExecuteAsync(DeleteQuery, new { pId = id }));
    }

    /// <summary>
    /// Opens a connection, runs <paramref name="query" /> against it and disposes the connection afterwards.
    /// </summary>
    /// <param name="query">Query method to execute.</param>
    /// <typeparam name="T">The return type.</typeparam>
    /// <returns>Result of query.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="query" /> is <see langword="null" />.</exception>
    /// <remarks>
    /// Exceptions thrown while opening the connection or by <paramref name="query" /> propagate to the caller.
    /// </remarks>
    protected async Task<T> WithDatabaseAsync<T>(Func<IDbConnection, Task<T>> query)
    {
        ArgumentNullException.ThrowIfNull(query);

        // This method requests the connection, so it owns it and must release it once the query has finished.
        // NOTE(review): assumes OpenConnectionAsync hands out a new connection rather than a shared one.
        using IDbConnection connection = await _db.OpenConnectionAsync();
        return await query(connection);
    }

    /// <summary>
    /// Opens a connection, runs <paramref name="query" /> against it and disposes the connection afterwards.
    /// </summary>
    /// <param name="query">Query method to execute.</param>
    /// <returns>A task that completes when the query has finished.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="query" /> is <see langword="null" />.</exception>
    /// <remarks>
    /// Exceptions thrown while opening the connection or by <paramref name="query" /> propagate to the caller.
    /// </remarks>
    protected async Task WithDatabaseAsync(Func<IDbConnection, Task> query)
    {
        ArgumentNullException.ThrowIfNull(query);

        // This method requests the connection, so it owns it and must release it once the query has finished.
        // NOTE(review): assumes OpenConnectionAsync hands out a new connection rather than a shared one.
        using IDbConnection connection = await _db.OpenConnectionAsync();
        await query(connection);
    }

    /// <summary>
    /// Resolves the table name, primary key and column mappings of <typeparamref name="TEntity" />
    /// and generates the SELECT, INSERT, UPDATE and DELETE statements.
    /// </summary>
    /// <exception cref="PrimaryKeyMissingException">No property is marked as primary key.</exception>
    /// <exception cref="MultiplePrimaryKeysException">More than one property is marked as primary key.</exception>
    /// <exception cref="MissingColumnAttributeException">The primary key property has no column mapping.</exception>
    private void SetupRepository()
    {
        var entityType = typeof(TEntity);
        var properties = entityType.GetProperties();

        // Table name: an explicit mapping wins, otherwise the type name with an "s" appended.
        var tableAttribute = entityType.GetCustomAttribute<TableMappingAttribute>();
        Table = tableAttribute is null ? entityType.Name + "s" : tableAttribute.Name;

        // Primary key: exactly one property must carry PrimaryKeyAttribute, and it must be column-mapped.
        var primaryKeyProperties = properties
            .Where(p => p.IsDefined(typeof(PrimaryKeyAttribute), true))
            .ToList();

        if (primaryKeyProperties.Count == 0)
        {
            throw new PrimaryKeyMissingException(entityType);
        }

        if (primaryKeyProperties.Count > 1)
        {
            throw new MultiplePrimaryKeysException(entityType);
        }

        var primaryKeyProperty = primaryKeyProperties[0];
        var primaryKey = primaryKeyProperty.GetCustomAttribute<PrimaryKeyAttribute>()!;
        var keyMapping = primaryKeyProperty.GetCustomAttribute<ColumnMappingAttribute>()
            ?? throw new MissingColumnAttributeException(entityType);

        PrimaryKeyColumn = keyMapping.Name;
        AutoPrimaryKey = primaryKey.AutoGenerated;

        SelectAllQuery = $"SELECT * FROM {Table};";
        FindByIdQuery = $"SELECT * FROM {Table} WHERE {PrimaryKeyColumn} = @pId;";

        // Every column-mapped property paired with its column name, in reflection order. Column names and
        // parameter placeholders are both derived from this single list so they always line up.
        var mappedColumns = new List<(PropertyInfo Property, string Column)>();
        foreach (var property in properties)
        {
            if (property.GetCustomAttribute<ColumnMappingAttribute>() is { } mapping)
            {
                mappedColumns.Add((property, mapping.Name));
            }
        }

        var nonKeyColumns = mappedColumns
            .Where(c => c.Property != primaryKeyProperty)
            .ToList();

        // Insert query: an auto-generated key is left for the database to fill in.
        var insertColumns = AutoPrimaryKey ? nonKeyColumns : mappedColumns;
        var columnList = string.Join(", ", insertColumns.Select(c => c.Column));
        var parameterList = string.Join(", ", insertColumns.Select(c => "@" + c.Property.Name));
        InsertQuery =
            $"INSERT INTO {Table} ({columnList}) VALUES ({parameterList}) RETURNING {PrimaryKeyColumn};";

        // Update query: assign every non-key column, matching the row on the key property.
        var assignments = nonKeyColumns
            .Select(c => $"{c.Column} = @{c.Property.Name}")
            .ToList();

        if (assignments.Count == 0)
        {
            // Key-only entity: an empty SET list is invalid SQL, so assign the key column to itself instead.
            assignments.Add($"{PrimaryKeyColumn} = {PrimaryKeyColumn}");
        }

        UpdateQuery =
            $"UPDATE {Table} SET {string.Join(", ", assignments)} WHERE {PrimaryKeyColumn} = @{primaryKeyProperty.Name};";

        // Delete query.
        DeleteQuery = $"DELETE FROM {Table} WHERE {PrimaryKeyColumn} = @pId;";
    }
}
|