// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information.

namespace WixToolset.Dtf.MakeSfxCA
{
    using System;
    using System.Collections.Generic;
    using System.IO;
    using System.Linq;
    using System.Reflection;
    using System.Security;
    using System.Text;
    using WixToolset.Dtf.Compression;
    using WixToolset.Dtf.Compression.Cab;
    using WixToolset.Dtf.Resources;
    using ResourceCollection = WixToolset.Dtf.Resources.ResourceCollection;

    /// <summary>
    /// Command-line tool for building self-extracting custom action packages.
    /// Appends cabbed CA binaries to SfxCA.dll and fixes up the result's
    /// entry-points and file version to look like the CA module.
    /// </summary>
    public static class MakeSfxCA
    {
        private const string REQUIRED_WI_ASSEMBLY = "WixToolset.Dtf.WindowsInstaller.dll";

        // Diagnostic log used by all helpers. Assigned by Build; Main passes
        // TextWriter.Null unless "-v" is given, in which case it is Console.Out.
        private static TextWriter log;

        /// <summary>
        /// Prints usage text for the tool.
        /// </summary>
        /// <param name="w">Console text writer.</param>
        private static void Usage(TextWriter w)
        {
            w.WriteLine("WiX Toolset custom action packager version {0}", Assembly.GetExecutingAssembly().GetName().Version);
            w.WriteLine("Copyright (C) .NET Foundation and contributors. All rights reserved.");
            w.WriteLine();
            w.WriteLine("Usage: WixToolset.Dtf.MakeSfxCA [-v] <outputca.dll> SfxCA.dll <inputca.dll> [support files ...]");
            w.WriteLine();
            w.WriteLine("Makes a self-extracting managed MSI CA or UI DLL package.");
            w.WriteLine("Support files must include " + MakeSfxCA.REQUIRED_WI_ASSEMBLY);
            w.WriteLine("Support files optionally include CustomAction.config/EmbeddedUI.config");
        }

        /// <summary>
        /// Runs the MakeSfxCA command-line tool.
        /// </summary>
        /// <param name="args">Command-line arguments. An argument of the form <c>@path</c>
        /// is replaced by the non-blank lines of the response file at <c>path</c>.</param>
        /// <returns>0 on success, nonzero on failure.</returns>
        public static int Main(string[] args)
        {
            var logger = TextWriter.Null;
            var output = String.Empty;
            var sfxDll = String.Empty;
            var inputs = new List<string>();

            // Argument expansion reads response files, so it runs inside the try: a missing or
            // unreadable response file is reported as an error instead of crashing the tool.
            try
            {
                // Positional order: output package, SfxCA.dll stub, then the CA module and support files.
                foreach (var arg in MakeSfxCA.ExpandArguments(args))
                {
                    if (arg == "-v")
                    {
                        logger = Console.Out;
                    }
                    else if (String.IsNullOrEmpty(output))
                    {
                        output = arg;
                    }
                    else if (String.IsNullOrEmpty(sfxDll))
                    {
                        sfxDll = arg;
                    }
                    else
                    {
                        inputs.Add(arg);
                    }
                }

                if (inputs.Count == 0)
                {
                    MakeSfxCA.Usage(Console.Out);
                    return 1;
                }

                MakeSfxCA.Build(output, sfxDll, inputs, logger);
                return 0;
            }
            catch (ArgumentException ex)
            {
                Console.Error.WriteLine("Error: Invalid argument: " + ex.Message);
                return 1;
            }
            catch (FileNotFoundException ex)
            {
                Console.Error.WriteLine("Error: Cannot find file: " + ex.Message);
                return 1;
            }
            catch (Exception ex)
            {
                Console.Error.WriteLine("Error: Unexpected error: " + ex);
                return 1;
            }
        }

        /// <summary>
        /// Reads the arguments, expanding any response files.
        /// </summary>
        /// <param name="args">Arguments to expand. Blank arguments are dropped; an argument
        /// starting with '@' names a response file whose lines are inserted in its place.</param>
        /// <returns>Expanded list of arguments.</returns>
        private static List<string> ExpandArguments(string[] args)
        {
            var result = new List<string>(args.Length);

            foreach (var arg in args)
            {
                if (String.IsNullOrWhiteSpace(arg))
                {
                    continue;
                }

                if (arg.StartsWith("@", StringComparison.Ordinal))
                {
                    // Each response-file line is one argument; surrounding whitespace and
                    // quotes are removed and blank lines are skipped.
                    var lines = File.ReadAllLines(arg.Substring(1));
                    result.AddRange(lines
                        .Select(line => line.Trim().Trim('"'))
                        .Where(line => !String.IsNullOrWhiteSpace(line)));
                }
                else
                {
                    result.Add(arg);
                }
            }

            return result;
        }

        /// <summary>
        /// Packages up all the inputs to the output location.
        /// </summary>
        /// <param name="output">Path of the self-extracting CA package to create.</param>
        /// <param name="sfxDll">Path of the SfxCA.dll stub.</param>
        /// <param name="inputs">The CA module first, followed by support files. Items may be
        /// semicolon-delimited lists; support files may use the <c>cabPath=diskPath</c> form.</param>
        /// <param name="log">Writer that receives progress messages.</param>
        /// <exception cref="ArgumentException">An argument is missing or invalid (missing
        /// arguments are reported as <see cref="ArgumentNullException"/>).</exception>
        /// <exception cref="FileNotFoundException">SfxCA.dll or the CA module does not exist.</exception>
        /// <exception cref="NotSupportedException">The module has both CA and UI entry points.</exception>
        /// <remarks>Other exceptions may propagate from file I/O, reflection, or packing.</remarks>
        private static void Build(string output, string sfxDll, IList<string> inputs, TextWriter log)
        {
            MakeSfxCA.log = log;

            if (String.IsNullOrEmpty(output))
            {
                throw new ArgumentNullException("output");
            }

            if (String.IsNullOrEmpty(sfxDll))
            {
                throw new ArgumentNullException("sfxDll");
            }

            if (inputs == null || inputs.Count == 0)
            {
                throw new ArgumentNullException("inputs");
            }

            if (!File.Exists(sfxDll))
            {
                throw new FileNotFoundException(sfxDll);
            }

            // Split semicolon-delimited items before picking the CA module, so that a first
            // argument such as "CA.dll;Support.dll" yields "CA.dll" rather than the whole list.
            inputs = MakeSfxCA.SplitList(inputs);
            if (inputs.Count == 0)
            {
                throw new ArgumentNullException("inputs");
            }

            var customActionAssembly = inputs[0];
            if (!File.Exists(customActionAssembly))
            {
                throw new FileNotFoundException(customActionAssembly);
            }

            var inputsMap = MakeSfxCA.GetPackFileMap(inputs);

            var foundWIAssembly = inputsMap.Keys.Any(
                key => String.Equals(key, MakeSfxCA.REQUIRED_WI_ASSEMBLY, StringComparison.OrdinalIgnoreCase));
            if (!foundWIAssembly)
            {
                throw new ArgumentException(MakeSfxCA.REQUIRED_WI_ASSEMBLY +
                    " must be included in the list of support files. " +
                    "If using the MSBuild targets, make sure the assembly reference " +
                    "has the Private (Copy Local) flag set.");
            }

            MakeSfxCA.ResolveDependentAssemblies(inputsMap, Path.GetDirectoryName(customActionAssembly));

            var entryPoints = MakeSfxCA.FindEntryPoints(customActionAssembly);
            var uiClass = MakeSfxCA.FindEmbeddedUIClass(customActionAssembly);

            if (entryPoints.Count == 0 && uiClass == null)
            {
                throw new ArgumentException(
                    "No CA or UI entry points found in module: " + customActionAssembly);
            }
            else if (entryPoints.Count > 0 && uiClass != null)
            {
                throw new NotSupportedException(
                    "CA and UI entry points cannot be in the same assembly: " + customActionAssembly);
            }

            // GetDirectoryName returns null for a root path and "" for a bare file name.
            var dir = Path.GetDirectoryName(output);
            if (!String.IsNullOrEmpty(dir) && !Directory.Exists(dir))
            {
                Directory.CreateDirectory(dir);
            }

            using (Stream outputStream = File.Create(output))
            {
                MakeSfxCA.WriteEntryModule(sfxDll, outputStream, entryPoints, uiClass);
            }

            MakeSfxCA.CopyVersionResource(customActionAssembly, output);

            MakeSfxCA.PackInputFiles(output, inputsMap);

            log.WriteLine("MakeSfxCA finished: " + new FileInfo(output).FullName);
        }

        /// <summary>
        /// Splits any list items delimited by semicolons into separate items.
        /// </summary>
        /// <param name="list">Read-only input list.</param>
        /// <returns>New list with resulting split items; null, empty items and empty
        /// segments are dropped.</returns>
        private static IList<string> SplitList(IList<string> list)
        {
            var newList = new List<string>(list.Count);

            foreach (var item in list)
            {
                if (!String.IsNullOrEmpty(item))
                {
                    newList.AddRange(item.Split(new char[] { ';' }, StringSplitOptions.RemoveEmptyEntries));
                }
            }

            return newList;
        }

        /// <summary>
        /// Sets up a reflection-only assembly-resolve-handler to handle loading dependent assemblies during reflection.
        /// </summary>
        /// <param name="inputFiles">Map of in-cab names to on-disk paths; includes non-GAC dependent assemblies.</param>
        /// <param name="inputDir">Directory to auto-locate additional dependent assemblies.</param>
        /// <remarks>
        /// Also searches the assembly's directory for unspecified dependent assemblies, and adds them
        /// to the list of input files if found.
        /// </remarks>
        private static void ResolveDependentAssemblies(IDictionary<string, string> inputFiles, string inputDir)
        {
            AppDomain.CurrentDomain.ReflectionOnlyAssemblyResolve += delegate (object sender, ResolveEventArgs args)
            {
                var resolveName = new AssemblyName(args.Name);
                Assembly assembly = null;

                // First, try to find the assembly in the list of input files.
                foreach (var inputFile in inputFiles.Values)
                {
                    var inputName = Path.GetFileNameWithoutExtension(inputFile);
                    var inputExtension = Path.GetExtension(inputFile);
                    if (String.Equals(inputName, resolveName.Name, StringComparison.OrdinalIgnoreCase) &&
                        (String.Equals(inputExtension, ".dll", StringComparison.OrdinalIgnoreCase) ||
                         String.Equals(inputExtension, ".exe", StringComparison.OrdinalIgnoreCase)))
                    {
                        assembly = MakeSfxCA.TryLoadDependentAssembly(inputFile);

                        if (assembly != null)
                        {
                            break;
                        }
                    }
                }

                // Second, try to find the assembly in the input directory.
                if (assembly == null && inputDir != null)
                {
                    string assemblyPath = null;
                    var basePath = Path.Combine(inputDir, resolveName.Name);
                    if (File.Exists(basePath + ".dll"))
                    {
                        assemblyPath = basePath + ".dll";
                    }
                    else if (File.Exists(basePath + ".exe"))
                    {
                        assemblyPath = basePath + ".exe";
                    }

                    if (assemblyPath != null)
                    {
                        assembly = MakeSfxCA.TryLoadDependentAssembly(assemblyPath);

                        if (assembly != null)
                        {
                            // Add this detected dependency to the list of files to be packed, unless an
                            // entry with the same in-cab name already exists (Add throws on duplicates).
                            var cabName = Path.GetFileName(assemblyPath);
                            if (!inputFiles.ContainsKey(cabName))
                            {
                                inputFiles.Add(cabName, assemblyPath);
                            }
                        }
                    }
                }

                // Third, try to load the assembly from the GAC.
                if (assembly == null)
                {
                    try
                    {
                        assembly = Assembly.ReflectionOnlyLoad(args.Name);
                    }
                    catch (FileNotFoundException)
                    {
                    }
                }

                if (assembly != null)
                {
                    if (String.Equals(assembly.GetName().ToString(), resolveName.ToString()))
                    {
                        log.WriteLine("    Loaded dependent assembly: " + assembly.Location);
                        return assembly;
                    }

                    log.WriteLine("    Warning: Loaded mismatched dependent assembly: " + assembly.Location);
                    log.WriteLine("      Loaded assembly   : " + assembly.GetName());
                    log.WriteLine("      Reference assembly: " + resolveName);
                }
                else
                {
                    log.WriteLine("    Error: Dependent assembly not supplied: " + resolveName);
                }

                return null;
            };
        }

        /// <summary>
        /// Attempts a reflection-only load of a dependent assembly, logging the error if the load fails.
        /// </summary>
        /// <param name="assemblyPath">Path of the assembly file to load.</param>
        /// <returns>Loaded assembly, or null if the load failed.</returns>
        private static Assembly TryLoadDependentAssembly(string assemblyPath)
        {
            Assembly assembly = null;
            try
            {
                assembly = Assembly.ReflectionOnlyLoadFrom(assemblyPath);
            }
            catch (IOException ex)
            {
                log.WriteLine("    Error: Failed to load dependent assembly: {0}. {1}", assemblyPath, ex.Message);
            }
            catch (BadImageFormatException ex)
            {
                log.WriteLine("    Error: Failed to load dependent assembly: {0}. {1}", assemblyPath, ex.Message);
            }
            catch (SecurityException ex)
            {
                log.WriteLine("    Error: Failed to load dependent assembly: {0}. {1}", assemblyPath, ex.Message);
            }

            return assembly;
        }

        /// <summary>
        /// Searches the types in the input assembly for a type that implements IEmbeddedUI.
        /// </summary>
        /// <param name="module">Path of the assembly to search.</param>
        /// <returns>"AssemblyName!Namespace.Type" of the single non-abstract exported type that
        /// implements WixToolset.Dtf.WindowsInstaller.IEmbeddedUI, or null if there is none.</returns>
        /// <exception cref="ArgumentException">More than one implementation was found.</exception>
        private static string FindEmbeddedUIClass(string module)
        {
            log.WriteLine("Searching for an embedded UI class in {0}", Path.GetFileName(module));

            string uiClass = null;

            var assembly = Assembly.ReflectionOnlyLoadFrom(module);

            foreach (var type in assembly.GetExportedTypes())
            {
                if (type.IsAbstract)
                {
                    continue;
                }

                foreach (var interfaceType in type.GetInterfaces())
                {
                    if (interfaceType.FullName == "WixToolset.Dtf.WindowsInstaller.IEmbeddedUI")
                    {
                        if (uiClass != null)
                        {
                            throw new ArgumentException("Multiple IEmbeddedUI implementations found.");
                        }

                        uiClass = assembly.GetName().Name + "!" + type.FullName;
                    }
                }
            }

            return uiClass;
        }

        /// <summary>
        /// Reflects on an input CA module to locate custom action entry-points.
        /// </summary>
        /// <param name="module">Assembly module with CA entry-points.</param>
        /// <returns>Mapping from entry-point names to assembly!class.method paths.</returns>
        /// <exception cref="ArgumentException">Two methods declare the same entry-point name.</exception>
        private static IDictionary<string, string> FindEntryPoints(string module)
        {
            log.WriteLine("Searching for custom action entry points " +
                "in {0}", Path.GetFileName(module));

            var entryPoints = new Dictionary<string, string>();

            var assembly = Assembly.ReflectionOnlyLoadFrom(module);

            foreach (var type in assembly.GetExportedTypes())
            {
                foreach (var method in type.GetMethods(BindingFlags.Public | BindingFlags.Static))
                {
                    var entryPointName = MakeSfxCA.GetEntryPoint(method);
                    if (entryPointName == null)
                    {
                        continue;
                    }

                    var entryPointPath = String.Format(
                        "{0}!{1}.{2}",
                        Path.GetFileNameWithoutExtension(module),
                        type.FullName,
                        method.Name);

                    // Report the conflict explicitly instead of the generic "same key" message.
                    string existingPath;
                    if (entryPoints.TryGetValue(entryPointName, out existingPath))
                    {
                        throw new ArgumentException(String.Format(
                            "Custom action entry point name {0} is used by both {1} and {2}.",
                            entryPointName,
                            existingPath,
                            entryPointPath));
                    }

                    entryPoints.Add(entryPointName, entryPointPath);

                    log.WriteLine("    {0}={1}", entryPointName, entryPointPath);
                }
            }

            return entryPoints;
        }

        /// <summary>
        /// Check for a CustomActionAttribute and return the entrypoint name for the method if it is a CA method.
        /// </summary>
        /// <param name="method">A public static method.</param>
        /// <returns>Entrypoint name for the method as specified by the custom action attribute or just the method name,
        /// or null if the method is not a custom action method.</returns>
        private static string GetEntryPoint(MethodInfo method)
        {
            IList<CustomAttributeData> attributes;
            try
            {
                attributes = CustomAttributeData.GetCustomAttributes(method);
            }
            catch (FileLoadException)
            {
                // Already logged load failures in the assembly-resolve-handler.
                return null;
            }

            foreach (var attribute in attributes)
            {
                // Compare by text because the attribute type lives in a reflection-only context.
                if (attribute.ToString().StartsWith(
                    "[WixToolset.Dtf.WindowsInstaller.CustomActionAttribute(",
                    StringComparison.Ordinal))
                {
                    // The entry point name is the first positional argument, if specified.
                    string entryPointName = null;
                    if (attribute.ConstructorArguments.Count > 0)
                    {
                        entryPointName = (string)attribute.ConstructorArguments[0].Value;
                    }

                    return String.IsNullOrEmpty(entryPointName) ? method.Name : entryPointName;
                }
            }

            return null;
        }

        /// <summary>
        /// Counts the number of template entrypoints in SfxCA.dll.
        /// </summary>
        /// <param name="fileBytes">Contents of SfxCA.dll.</param>
        /// <param name="entryPointFormat">Format string producing the template name for a slot index.</param>
        /// <returns>The number of consecutive template names, starting at index 0, found as ASCII text.</returns>
        /// <remarks>
        /// Depending on the requirements, SfxCA.dll might be built with
        /// more entrypoints than the default.
        /// </remarks>
        private static int GetEntryPointSlotCount(byte[] fileBytes, string entryPointFormat)
        {
            var count = 0;
            while (MakeSfxCA.FindBytes(fileBytes, Encoding.ASCII.GetBytes(String.Format(entryPointFormat, count))) >= 0)
            {
                count++;
            }

            return count;
        }

        /// <summary>
        /// Writes a modified version of SfxCA.dll to the output stream,
        /// with the template entry-points mapped to the CA entry-points.
        /// </summary>
        /// <param name="sfxDll">Path of the SfxCA.dll stub.</param>
        /// <param name="outputStream">Stream that receives the modified stub.</param>
        /// <param name="entryPoints">Mapping from entry-point names to assembly!class.method paths.</param>
        /// <param name="uiClass">Embedded UI class ("Assembly!Type"), or null.</param>
        /// <exception cref="ArgumentException">The stub is invalid or a name/path is too long.</exception>
        /// <remarks>
        /// To avoid having to recompile SfxCA.dll for every different set of CAs,
        /// this method looks for a preset number of template entry-points in the
        /// binary file and overwrites their entrypoint name and string data with
        /// CA-specific values.
        /// </remarks>
        private static void WriteEntryModule(
            string sfxDll, Stream outputStream, IDictionary<string, string> entryPoints, string uiClass)
        {
            log.WriteLine("Modifying SfxCA.dll stub");

            // A single Stream.Read call may return fewer bytes than requested; ReadAllBytes reads the whole file.
            var fileBytes = File.ReadAllBytes(sfxDll);

            const string ENTRYPOINT_FORMAT = "CustomActionEntryPoint{0:d03}";
            const int MAX_ENTRYPOINT_NAME = 72;
            const int MAX_ENTRYPOINT_PATH = 160;

            var slotCount = MakeSfxCA.GetEntryPointSlotCount(fileBytes, ENTRYPOINT_FORMAT);

            if (slotCount == 0)
            {
                throw new ArgumentException("Invalid SfxCA.dll file.");
            }

            if (entryPoints.Count > slotCount)
            {
                throw new ArgumentException(String.Format(
                    "The custom action assembly has {0} entrypoints, which is more than the maximum ({1}). " +
                    "Refactor the custom actions or add more entrypoint slots in SfxCA\\EntryPoints.h.",
                    entryPoints.Count, slotCount));
            }

            // Unused slots (empty names) go first, followed by the real entry point names in ordinal
            // order. NOTE(review): presumably this keeps the export name table sorted for the loader;
            // confirm against SfxCA before changing the ordering.
            var firstUsedSlot = slotCount - entryPoints.Count;
            var slotSort = new string[slotCount];
            for (var i = 0; i < firstUsedSlot; i++)
            {
                slotSort[i] = String.Empty;
            }

            entryPoints.Keys.CopyTo(slotSort, firstUsedSlot);
            Array.Sort(slotSort, firstUsedSlot, entryPoints.Count, StringComparer.Ordinal);

            for (var i = 0; i < slotCount; i++)
            {
                var templateName = String.Format(ENTRYPOINT_FORMAT, i);
                var nameOffset = MakeSfxCA.FindBytes(fileBytes, Encoding.ASCII.GetBytes(templateName));
                var pathOffset = MakeSfxCA.FindBytes(fileBytes, Encoding.Unicode.GetBytes(templateName));

                // Every counted slot needs both its ASCII export name and its Unicode path template;
                // stopping early would silently drop the real entry points, which occupy the last slots.
                if (nameOffset < 0 || pathOffset < 0)
                {
                    throw new ArgumentException("Invalid SfxCA.dll file. Incomplete entry point slot: " + templateName);
                }

                var entryPointName = slotSort[i];
                var entryPointPath = entryPointName.Length > 0 ? entryPoints[entryPointName] : String.Empty;

                if (entryPointName.Length > MAX_ENTRYPOINT_NAME)
                {
                    throw new ArgumentException(String.Format(
                        "Entry point name exceeds limit of {0} characters: {1}",
                        MAX_ENTRYPOINT_NAME,
                        entryPointName));
                }

                if (entryPointPath.Length > MAX_ENTRYPOINT_PATH)
                {
                    throw new ArgumentException(String.Format(
                        "Entry point path exceeds limit of {0} characters: {1}",
                        MAX_ENTRYPOINT_PATH,
                        entryPointPath));
                }

                MakeSfxCA.ReplaceBytes(fileBytes, nameOffset, MAX_ENTRYPOINT_NAME, Encoding.ASCII.GetBytes(entryPointName));
                MakeSfxCA.ReplaceBytes(fileBytes, pathOffset, MAX_ENTRYPOINT_PATH * 2, Encoding.Unicode.GetBytes(entryPointPath));
            }

            if (entryPoints.Count == 0 && uiClass != null)
            {
                // Validate before touching the buffer.
                if (uiClass.Length > MAX_ENTRYPOINT_PATH)
                {
                    throw new ArgumentException(String.Format(
                        "UI class full name exceeds limit of {0} characters: {1}",
                        MAX_ENTRYPOINT_PATH,
                        uiClass));
                }

                // Remove the zzz prefix from exported EmbeddedUI entry-points.
                foreach (var export in new string[] { "InitializeEmbeddedUI", "EmbeddedUIHandler", "ShutdownEmbeddedUI" })
                {
                    var exportNameBytes = Encoding.ASCII.GetBytes("zzz" + export);
                    var exportOffset = MakeSfxCA.FindBytes(fileBytes, exportNameBytes);
                    if (exportOffset < 0)
                    {
                        throw new ArgumentException("Input SfxCA.dll does not contain exported entry-point: " + export);
                    }

                    MakeSfxCA.ReplaceBytes(fileBytes, exportOffset, exportNameBytes.Length, Encoding.ASCII.GetBytes(export));
                }

                // Fill in the embedded UI implementor class so the proxy knows which one to load.
                var templateBytes = Encoding.Unicode.GetBytes("InitializeEmbeddedUI_FullClassName");
                var replaceOffset = MakeSfxCA.FindBytes(fileBytes, templateBytes);
                if (replaceOffset >= 0)
                {
                    MakeSfxCA.ReplaceBytes(fileBytes, replaceOffset, MAX_ENTRYPOINT_PATH * 2, Encoding.Unicode.GetBytes(uiClass));
                }
            }

            outputStream.Write(fileBytes, 0, fileBytes.Length);
        }

        /// <summary>
        /// Searches for a sub-array of bytes within a larger array of bytes.
        /// </summary>
        /// <param name="source">Array to search.</param>
        /// <param name="find">Byte sequence to look for.</param>
        /// <returns>Offset of the first occurrence of <paramref name="find"/>, or -1 if not found.</returns>
        private static int FindBytes(byte[] source, byte[] find)
        {
            // Only start positions where the whole pattern fits are considered, so a partial
            // match at the end of the array can never index past its bounds.
            var lastStart = source.Length - find.Length;
            for (var i = 0; i <= lastStart; i++)
            {
                var j = 0;
                while (j < find.Length && source[i + j] == find[j])
                {
                    j++;
                }

                if (j == find.Length)
                {
                    return i;
                }
            }

            return -1;
        }

        /// <summary>
        /// Replaces a range of bytes with new bytes, padding any extra part
        /// of the range with zeroes.
        /// </summary>
        /// <param name="source">Array modified in place.</param>
        /// <param name="offset">Start of the range.</param>
        /// <param name="length">Length of the range; bytes of <paramref name="replace"/> beyond it are not written.</param>
        /// <param name="replace">Replacement bytes.</param>
        private static void ReplaceBytes(
            byte[] source, int offset, int length, byte[] replace)
        {
            for (var i = 0; i < length; i++)
            {
                source[offset + i] = i < replace.Length ? replace[i] : (byte)0;
            }
        }

        /// <summary>
        /// Print the name of one file as it is being packed into the cab.
        /// </summary>
        private static void PackProgress(object source, ArchiveProgressEventArgs e)
        {
            if (e.ProgressType == ArchiveProgressType.StartFile && log != null)
            {
                log.WriteLine("    {0}", e.CurrentFileName);
            }
        }

        /// <summary>
        /// Gets a mapping from filenames as they will be in the cab to filenames
        /// as they are currently on disk.
        /// </summary>
        /// <param name="inputs">Input items, each either a path or <c>cabPath=diskPath</c>.</param>
        /// <returns>Map from in-cab name to on-disk path; the first item for a given in-cab name wins.</returns>
        /// <remarks>
        /// By default, all files will be placed in the root of the cab. But inputs may
        /// optionally include an alternate inside-cab file path before an equals sign.
        /// </remarks>
        private static IDictionary<string, string> GetPackFileMap(IList<string> inputs)
        {
            var fileMap = new Dictionary<string, string>();

            foreach (var inputFile in inputs)
            {
                string cabName;
                string sourcePath;

                var equalsIndex = inputFile.IndexOf('=');
                if (equalsIndex > 0)
                {
                    // Split only at the first '=', so the on-disk path may itself contain '='.
                    cabName = inputFile.Substring(0, equalsIndex);
                    sourcePath = inputFile.Substring(equalsIndex + 1);
                }
                else
                {
                    cabName = Path.GetFileName(inputFile);
                    sourcePath = inputFile;
                }

                if (!fileMap.ContainsKey(cabName))
                {
                    fileMap.Add(cabName, sourcePath);
                }
            }

            return fileMap;
        }

        /// <summary>
        /// Packs the input files into a cab that is appended to the
        /// output SfxCA.dll.
        /// </summary>
        /// <param name="outputFile">The package file produced by WriteEntryModule.</param>
        /// <param name="fileMap">Map from in-cab name to on-disk path.</param>
        private static void PackInputFiles(string outputFile, IDictionary<string, string> fileMap)
        {
            log.WriteLine("Packaging files");

            var cabInfo = new CabInfo(outputFile);
            cabInfo.PackFileSet(null, fileMap, CompressionLevel.Max, PackProgress);
        }

        /// <summary>
        /// Copies the version resource information from the CA module to
        /// the CA package. This gives the package the file version and
        /// description of the CA module, instead of the version and
        /// description of SfxCA.dll.
        /// </summary>
        /// <param name="sourceFile">CA module to read the version resource from.</param>
        /// <param name="destFile">Package file to write the version resource to.</param>
        private static void CopyVersionResource(string sourceFile, string destFile)
        {
            log.WriteLine("Copying file version info from {0} to {1}",
                sourceFile, destFile);

            var rc = new ResourceCollection();
            rc.Find(sourceFile, ResourceType.Version);
            rc.Load(sourceFile);
            rc.Save(destFile);
        }
    }
}