// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information.

namespace WixToolset.Dtf.WindowsInstaller.Linq
{
    using System;
    using System.IO;
    using System.Collections;
    using System.Collections.Generic;
    using System.Reflection;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Diagnostics.CodeAnalysis;

    /// <summary>
    /// Represents one table in a LINQ-queryable Database.
    /// </summary>
    /// <typeparam name="TRecord">type that represents one record in the table</typeparam>
    /// <remarks>
    /// This class is the primary gateway to all LINQ to MSI query functionality.
    /// <para>The TRecord generic parameter may be the general <see cref="QRecord" />
    /// class, or a specialized subclass of QRecord.</para>
    /// <para>The table acts as its own query provider: the supported query
    /// operators are Where, OrderBy, Select and Join. Direct execution through
    /// the provider is not supported.</para>
    /// </remarks>
    [SuppressMessage("Microsoft.Naming", "CA1710:IdentifiersShouldHaveCorrectSuffix")]
    public sealed class QTable<TRecord> : IOrderedQueryable<TRecord>, IQueryProvider
        where TRecord : QRecord, new()
    {
        // Both fields are assigned exactly once, in the constructor.
        private readonly QDatabase db;
        private readonly TableInfo tableInfo;

        /// <summary>
        /// Infers the name of the table this instance will be
        /// associated with.
        /// </summary>
        /// <returns>table name</returns>
        /// <remarks>
        /// The table name is retrieved from a DatabaseTableAttribute
        /// on the record type if it exists; otherwise the name is
        /// derived from the name of the record type itself.
        /// (An optional underscore suffix on the record type name is dropped.)
        /// </remarks>
        private static string InferTableName()
        {
            // Only attributes declared directly on TRecord are considered
            // (inherit: false); the first non-empty Table value wins.
            foreach (DatabaseTableAttribute attr in typeof(TRecord).GetCustomAttributes(
                typeof(DatabaseTableAttribute), false))
            {
                string tableName = attr.Table;
                if (!String.IsNullOrEmpty(tableName))
                {
                    return tableName;
                }
            }

            string recordTypeName = typeof(TRecord).Name;
            if (recordTypeName[recordTypeName.Length - 1] == '_')
            {
                return recordTypeName.Substring(0, recordTypeName.Length - 1);
            }
            else
            {
                return recordTypeName;
            }
        }

        /// <summary>
        /// Creates a new QTable, inferring the table name
        /// from the name of the record type parameter.
        /// </summary>
        /// <param name="db">database that contains the table</param>
        /// <exception cref="ArgumentNullException">db is null</exception>
        /// <exception cref="ArgumentException">the inferred table does not
        /// exist in the database</exception>
        public QTable(QDatabase db)
            : this(db, InferTableName())
        {
        }

        /// <summary>
        /// Creates a new QTable with an explicit table name.
        /// </summary>
        /// <param name="db">database that contains the table</param>
        /// <param name="table">name of the table</param>
        /// <exception cref="ArgumentNullException">db is null, or table is
        /// null or empty</exception>
        /// <exception cref="ArgumentException">the table does not exist
        /// in the database</exception>
        public QTable(QDatabase db, string table)
        {
            if (db == null)
            {
                throw new ArgumentNullException("db");
            }

            if (String.IsNullOrEmpty(table))
            {
                throw new ArgumentNullException("table");
            }

            this.db = db;
            this.tableInfo = db.Tables[table];
            if (this.tableInfo == null)
            {
                throw new ArgumentException(
                    "Table does not exist in database: " + table);
            }
        }

        /// <summary>
        /// Gets schema information about the table.
        /// </summary>
        public TableInfo TableInfo
        {
            get
            {
                return this.tableInfo;
            }
        }

        /// <summary>
        /// Gets the database this table is associated with.
        /// </summary>
        public QDatabase Database
        {
            get
            {
                return this.db;
            }
        }

        /// <summary>
        /// Enumerates over all records in the table.
        /// </summary>
        /// <returns>an enumerator that yields one TRecord per row returned
        /// by the table's select query</returns>
        /// <remarks>
        /// This is an iterator method, so the query is not opened or executed
        /// until enumeration begins. If the database has a log writer, the
        /// query text is written to it first. Columns whose type is
        /// <see cref="System.IO.Stream" /> are not read; their value is the
        /// placeholder string "[Binary Data]".
        /// </remarks>
        public IEnumerator<TRecord> GetEnumerator()
        {
            string query = this.tableInfo.SqlSelectString;

            TextWriter log = this.db.Log;
            if (log != null)
            {
                log.WriteLine();
                log.WriteLine(query);
            }

            using (View view = this.db.OpenView(query))
            {
                view.Execute();

                // Work out once, up front, which columns hold binary data.
                ColumnCollection columns = this.tableInfo.Columns;
                int columnCount = columns.Count;
                bool[] isBinary = new bool[columnCount];

                for (int i = 0; i < isBinary.Length; i++)
                {
                    isBinary[i] = columns[i].Type == typeof(System.IO.Stream);
                }

                foreach (Record rec in view) using (rec)
                {
                    string[] values = new string[columnCount];
                    for (int i = 0; i < values.Length; i++)
                    {
                        // The record field index is the column index plus one.
                        values[i] = isBinary[i] ? "[Binary Data]" : rec.GetString(i + 1);
                    }

                    TRecord trec = new TRecord();
                    trec.Database = this.Database;
                    trec.TableInfo = this.TableInfo;
                    trec.Values = values;
                    trec.Exists = true;
                    yield return trec;
                }
            }
        }

        /// <summary>
        /// Non-generic enumeration; forwards to the typed enumerator.
        /// </summary>
        IEnumerator IEnumerable.GetEnumerator()
        {
            // The cast must target the generic interface; casting to the
            // non-generic IEnumerable would re-enter this same method.
            return ((IEnumerable<TRecord>) this).GetEnumerator();
        }

        /// <summary>
        /// Builds a query object for a single query operator applied to this table.
        /// </summary>
        /// <typeparam name="TElement">element type of the resulting query</typeparam>
        /// <param name="expression">method-call expression for one of the
        /// supported operators: Where, OrderBy, Select or Join</param>
        /// <returns>a query that represents the operator applied to this table</returns>
        /// <exception cref="ArgumentNullException">expression is null</exception>
        /// <exception cref="NotSupportedException">the called method is not one
        /// of the supported operators</exception>
        /// <remarks>
        /// The expression is cast to <see cref="MethodCallExpression" /> without
        /// a type check, so any other expression kind fails with an
        /// InvalidCastException.
        /// </remarks>
        IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
        {
            if (expression == null)
            {
                throw new ArgumentNullException("expression");
            }

            // NOTE(review): the generic argument on Query is reconstructed from
            // the IQueryable<TElement> return type -- confirm against the
            // declaration of the Query class.
            Query<TElement> q = new Query<TElement>(this.Database, expression);

            MethodCallExpression methodCallExpression = (MethodCallExpression) expression;
            string methodName = methodCallExpression.Method.Name;

            // For Where/OrderBy/Select the second argument is a unary
            // expression whose operand is the lambda to translate.
            if (methodName == "Where")
            {
                LambdaExpression argumentExpression = (LambdaExpression)
                    ((UnaryExpression) methodCallExpression.Arguments[1]).Operand;
                q.BuildQuery(this.TableInfo, argumentExpression);
            }
            else if (methodName == "OrderBy")
            {
                LambdaExpression argumentExpression = (LambdaExpression)
                    ((UnaryExpression) methodCallExpression.Arguments[1]).Operand;
                q.BuildSequence(this.TableInfo, argumentExpression);
            }
            else if (methodName == "Select")
            {
                LambdaExpression argumentExpression = (LambdaExpression)
                    ((UnaryExpression) methodCallExpression.Arguments[1]).Operand;
                q.BuildNullQuery(this.TableInfo, typeof(TRecord), argumentExpression);
                q.BuildProjection(null, argumentExpression);
            }
            else if (methodName == "Join")
            {
                // Arguments: [1] inner sequence (a constant), then three
                // lambdas at [2], [3] and [4], passed through in that order.
                ConstantExpression constantExpression = (ConstantExpression)
                    methodCallExpression.Arguments[1];
                IQueryable inner = (IQueryable) constantExpression.Value;
                q.PerformJoin(
                    this.TableInfo,
                    typeof(TRecord),
                    inner,
                    GetJoinLambda(methodCallExpression.Arguments[2]),
                    GetJoinLambda(methodCallExpression.Arguments[3]),
                    GetJoinLambda(methodCallExpression.Arguments[4]));
            }
            else
            {
                throw new NotSupportedException(
                    "Query operation not supported: " + methodName);
            }

            return q;
        }

        /// <summary>
        /// Extracts the lambda wrapped by a unary expression argument.
        /// </summary>
        /// <param name="expression">a unary expression whose operand is a lambda</param>
        /// <returns>the operand, cast to LambdaExpression</returns>
        private static LambdaExpression GetJoinLambda(Expression expression)
        {
            UnaryExpression unaryExpression = (UnaryExpression) expression;
            return (LambdaExpression) unaryExpression.Operand;
        }

        /// <summary>
        /// Non-generic query creation; forwards to the generic overload
        /// with TRecord as the element type.
        /// </summary>
        IQueryable IQueryProvider.CreateQuery(Expression expression)
        {
            // Must call the generic overload; calling the non-generic one
            // here would recurse forever.
            return ((IQueryProvider) this).CreateQuery<TRecord>(expression);
        }

        /// <summary>
        /// Not supported; always throws.
        /// </summary>
        /// <exception cref="NotSupportedException">always</exception>
        TResult IQueryProvider.Execute<TResult>(Expression expression)
        {
            throw new NotSupportedException(
                "Direct method calls not supported -- use AsEnumerable() instead.");
        }

        /// <summary>
        /// Not supported; always throws.
        /// </summary>
        /// <exception cref="NotSupportedException">always</exception>
        object IQueryProvider.Execute(Expression expression)
        {
            throw new NotSupportedException(
                "Direct method calls not supported -- use AsEnumerable() instead.");
        }

        /// <summary>
        /// Gets the query provider, which is this table itself.
        /// </summary>
        IQueryProvider IQueryable.Provider
        {
            get
            {
                return this;
            }
        }

        /// <summary>
        /// Gets the element type of the query, which is the record type.
        /// </summary>
        Type IQueryable.ElementType
        {
            get
            {
                return typeof(TRecord);
            }
        }

        /// <summary>
        /// Gets the root expression of the query: a constant that wraps this table.
        /// </summary>
        Expression IQueryable.Expression
        {
            get
            {
                return Expression.Constant(this);
            }
        }

        /// <summary>
        /// Creates a new record that can be inserted into this table.
        /// </summary>
        /// <returns>a record with all fields initialized to null</returns>
        /// <remarks>
        /// Primary keys and required fields must be filled in with
        /// non-null values before the record can be inserted.
        /// <para>The record is tied to this table in this database;
        /// it cannot be inserted into another table or database.</para>
        /// </remarks>
        public TRecord NewRecord()
        {
            TRecord rec = new TRecord();
            rec.Database = this.Database;
            rec.TableInfo = this.TableInfo;

            // One null slot per column; Exists is deliberately left unset.
            IList<string> values = new List<string>(this.TableInfo.Columns.Count);
            for (int i = 0; i < this.TableInfo.Columns.Count; i++)
            {
                values.Add(null);
            }

            rec.Values = values;
            return rec;
        }
    }
}