astral-api/Astral.DAL/Repositories/BaseRepository.cs
2024-12-11 20:36:30 +00:00

252 lines
8.3 KiB
C#

// <copyright file="BaseRepository.cs" company="alveus.dev">
// Copyright (c) alveus.dev. All rights reserved. Licensed under the MIT License.
// </copyright>
using System.Data;
using System.Reflection;
using System.Text;
using Astral.Core.Attributes.EntityAnnotation;
using Astral.Core.Exceptions;
using Astral.Core.Infrastructure;
using Astral.Core.RepositoryInterfaces;
using Dapper;
namespace Astral.DAL.Repositories;
/// <inheritdoc />
/// <remarks>
/// All SQL statements are generated once, in the constructor, from the
/// <see cref="TableMappingAttribute" />, <see cref="PrimaryKeyAttribute" /> and
/// <see cref="ColumnMappingAttribute" /> annotations on <typeparamref name="TEntity" />.
/// Table and column identifiers come from those attributes; entity values and ids are always
/// passed as Dapper parameters.
/// </remarks>
public class BaseRepository<TEntity, TKeyType> : IGenericRepository<TEntity, TKeyType>
    where TEntity : class
{
    /// <summary>
    /// Instance of <see cref="IDbConnectionProvider" />.
    /// </summary>
    private readonly IDbConnectionProvider _db;

    /// <summary>
    /// Initializes a new instance of the <see cref="BaseRepository{TEntity, TKeyType}" /> class
    /// and generates the CRUD queries for <typeparamref name="TEntity" />.
    /// </summary>
    /// <param name="db">Instance of <see cref="IDbConnectionProvider" />.</param>
    /// <exception cref="PrimaryKeyMissingException">
    /// No property of <typeparamref name="TEntity" /> carries <see cref="PrimaryKeyAttribute" />.
    /// </exception>
    /// <exception cref="MultiplePrimaryKeysException">
    /// More than one property of <typeparamref name="TEntity" /> carries <see cref="PrimaryKeyAttribute" />.
    /// </exception>
    /// <exception cref="MissingColumnAttributeException">
    /// The primary key property has no <see cref="ColumnMappingAttribute" />.
    /// </exception>
    protected BaseRepository(IDbConnectionProvider db)
    {
        _db = db;
        SetupRepository();
    }

    /// <summary>
    /// Table name: the <see cref="TableMappingAttribute" /> name if present, otherwise the entity
    /// type name with an "s" appended. The queries are built from this value during construction,
    /// so assigning it afterwards does not change them.
    /// </summary>
    protected string Table { get; set; }

    /// <summary>
    /// Primary key column name.
    /// </summary>
    private string PrimaryKeyColumn { get; set; }

    /// <summary>
    /// Whether the primary key should be treated as auto-generated by the database.
    /// </summary>
    private bool AutoPrimaryKey { get; set; }

    /// <summary>
    /// Generated select all query string.
    /// </summary>
    private string SelectAllQuery { get; set; }

    /// <summary>
    /// Generated find by id query string (id bound as <c>@pId</c>).
    /// </summary>
    private string FindByIdQuery { get; set; }

    /// <summary>
    /// Generated insert query string; returns the primary key column.
    /// </summary>
    private string InsertQuery { get; set; }

    /// <summary>
    /// Generated update query string. Only valid SQL when <see cref="HasUpdatableColumns" /> is true.
    /// </summary>
    private string UpdateQuery { get; set; }

    /// <summary>
    /// Whether the entity maps at least one non-key column, i.e. whether an UPDATE is possible.
    /// </summary>
    private bool HasUpdatableColumns { get; set; }

    /// <summary>
    /// Generated delete query string (id bound as <c>@pId</c>).
    /// </summary>
    private string DeleteQuery { get; set; }

    /// <inheritdoc cref="IGenericRepository{TEntity,TKeyType}" />
    /// <remarks>Returns <c>null</c> when no row matches (Dapper <c>QueryFirstOrDefault</c>).</remarks>
    public async Task<TEntity> FindByIdAsync(TKeyType id)
    {
        return await WithDatabaseAsync(async connection =>
        {
            var result = await connection.QueryFirstOrDefaultAsync<TEntity>(FindByIdQuery, new
            {
                pId = id
            });
            return result;
        });
    }

    /// <inheritdoc cref="IGenericRepository{TEntity,TKeyType}" />
    public async Task<IEnumerable<TEntity>> GetAllAsync()
    {
        return await WithDatabaseAsync(async connection =>
        {
            var result = await connection.QueryAsync<TEntity>(SelectAllQuery);
            return result;
        });
    }

    /// <inheritdoc cref="IGenericRepository{TEntity,TKeyType}" />
    /// <remarks>Returns the primary key value produced by the insert's <c>RETURNING</c> clause.</remarks>
    public async Task<TKeyType> AddAsync(TEntity entity)
    {
        return await WithDatabaseAsync(async connection => await connection.QuerySingleAsync<TKeyType>(
            InsertQuery, entity));
    }

    /// <inheritdoc cref="IGenericRepository{TEntity,TKeyType}" />
    /// <remarks>Dapper executes the insert statement once per element of <paramref name="entities" />.</remarks>
    public async Task AddAsync(IEnumerable<TEntity> entities)
    {
        await WithDatabaseAsync(async connection =>
        {
            await connection.ExecuteAsync(
                InsertQuery, entities);
        });
    }

    /// <inheritdoc cref="IGenericRepository{TEntity,TKeyType}" />
    /// <exception cref="NotSupportedException">
    /// <typeparamref name="TEntity" /> has no mapped non-key columns, so there is nothing to update.
    /// </exception>
    public async Task<bool> UpdateAsync(TEntity entity)
    {
        if (!HasUpdatableColumns)
        {
            // Without this guard the statement would be "UPDATE t SET WHERE ..." and fail in the database.
            throw new NotSupportedException(
                $"{typeof(TEntity).Name} has no mapped non-key columns, so it cannot be updated.");
        }

        return await WithDatabaseAsync(async connection => await connection.ExecuteAsync(UpdateQuery, entity) > 0);
    }

    /// <inheritdoc cref="IGenericRepository{TEntity,TKeyType}" />
    public async Task DeleteAsync(TKeyType id)
    {
        await WithDatabaseAsync(async connection => await connection.ExecuteAsync(DeleteQuery, new
        {
            pId = id
        }));
    }

    /// <summary>
    /// Open a connection through the provider and run <paramref name="query" /> against it.
    /// </summary>
    /// <param name="query">Query method to execute.</param>
    /// <typeparam name="T">The return type.</typeparam>
    /// <returns>Result of query.</returns>
    /// <remarks>
    /// Exceptions from opening the connection or from <paramref name="query" /> propagate unchanged.
    /// NOTE(review): the connection is not disposed here. Confirm that <see cref="IDbConnectionProvider" />
    /// owns the lifetime of the connections it returns; if it hands out a new connection per call,
    /// this should be wrapped in a <c>using</c>.
    /// </remarks>
    protected async Task<T> WithDatabaseAsync<T>(Func<IDbConnection, Task<T>> query)
    {
        var connection = await _db.OpenConnectionAsync();
        return await query(connection);
    }

    /// <summary>
    /// Open a connection through the provider and run <paramref name="query" /> against it.
    /// </summary>
    /// <param name="query">Query method to execute.</param>
    /// <returns>A task that completes when <paramref name="query" /> completes.</returns>
    /// <remarks>
    /// NOTE(review): as with the generic overload, the connection is not disposed here; confirm
    /// the provider's ownership contract.
    /// </remarks>
    protected async Task WithDatabaseAsync(Func<IDbConnection, Task> query)
    {
        var connection = await _db.OpenConnectionAsync();
        await query(connection);
    }

    /// <summary>
    /// Find the single property marked with <see cref="PrimaryKeyAttribute" />.
    /// </summary>
    /// <param name="properties">Public properties of <typeparamref name="TEntity" />.</param>
    /// <returns>The primary key property.</returns>
    /// <exception cref="PrimaryKeyMissingException">No property is marked as the key.</exception>
    /// <exception cref="MultiplePrimaryKeysException">More than one property is marked as the key.</exception>
    private static PropertyInfo GetPrimaryKeyProperty(IEnumerable<PropertyInfo> properties)
    {
        // The one-argument IsDefined extension honours attribute inheritance on overridden
        // properties, matching GetCustomAttribute<T>(); PropertyInfo.IsDefined(Type, bool) does not.
        var keys = properties.Where(p => p.IsDefined(typeof(PrimaryKeyAttribute))).ToList();

        return keys.Count switch
        {
            0 => throw new PrimaryKeyMissingException(typeof(TEntity)),
            1 => keys[0],
            _ => throw new MultiplePrimaryKeysException(typeof(TEntity)),
        };
    }

    /// <summary>
    /// Resolve table/key metadata and generate all CRUD query strings.
    /// </summary>
    private void SetupRepository()
    {
        var entityType = typeof(TEntity);
        var properties = entityType.GetProperties();

        // Table name.
        var tableAttribute = entityType.GetCustomAttribute<TableMappingAttribute>();
        Table = tableAttribute is null ? entityType.Name + "s" : tableAttribute.Name;

        // Primary key.
        var primaryKeyProperty = GetPrimaryKeyProperty(properties);
        var primaryKey = primaryKeyProperty.GetCustomAttribute<PrimaryKeyAttribute>()!;
        var primaryKeyCol = primaryKeyProperty.GetCustomAttribute<ColumnMappingAttribute>()
            ?? throw new MissingColumnAttributeException(entityType);

        PrimaryKeyColumn = primaryKeyCol.Name;
        AutoPrimaryKey = primaryKey.AutoGenerated;

        // Every property with a column mapping, in declaration order. Both the INSERT column list
        // and its VALUES parameter list are derived from this one list so they can never disagree.
        var mapped = properties
            .Select(p => (Property: p, Column: p.GetCustomAttribute<ColumnMappingAttribute>()))
            .Where(m => m.Column is not null)
            .Select(m => (m.Property, ColumnName: m.Column!.Name, IsKey: m.Property == primaryKeyProperty))
            .ToList();

        SelectAllQuery = $"SELECT * FROM {Table};";
        FindByIdQuery = $"SELECT * FROM {Table} WHERE {PrimaryKeyColumn} = @pId;";

        // Insert: an auto-generated key is left for the database to fill in.
        var insertable = mapped.Where(m => !(AutoPrimaryKey && m.IsKey)).ToList();
        var columns = string.Join(", ", insertable.Select(m => m.ColumnName));
        var parameters = string.Join(", ", insertable.Select(m => $"@{m.Property.Name}"));
        InsertQuery = $"INSERT INTO {Table} ({columns}) VALUES ({parameters}) RETURNING {PrimaryKeyColumn};";

        // Update: every mapped non-key column; the key selects the row.
        var assignments = mapped
            .Where(m => !m.IsKey)
            .Select(m => $"{m.ColumnName} = @{m.Property.Name}")
            .ToList();
        HasUpdatableColumns = assignments.Count > 0;
        UpdateQuery =
            $"UPDATE {Table} SET {string.Join(", ", assignments)} WHERE {PrimaryKeyColumn} = @{primaryKeyProperty.Name};";

        // Delete.
        DeleteQuery = $"DELETE FROM {Table} WHERE {PrimaryKeyColumn} = @pId;";
    }
}