// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information.
namespace WixToolset.Dtf.MakeSfxCA
{
    using System;
    using System.Collections.Generic;
    using System.IO;
    using System.Linq;
    using System.Reflection;
    using System.Security;
    using System.Text;
    using WixToolset.Dtf.Compression;
    using WixToolset.Dtf.Compression.Cab;
    using WixToolset.Dtf.Resources;
    using ResourceCollection = WixToolset.Dtf.Resources.ResourceCollection;

    /// <summary>
    /// Command-line tool for building self-extracting custom action packages.
    /// Appends cabbed CA binaries to SfxCA.dll and fixes up the result's
    /// entry-points and file version to look like the CA module.
    /// </summary>
    public static class MakeSfxCA
    {
        /// <summary>
        /// File name of the assembly that must be present among the packed support files.
        /// </summary>
        private const string REQUIRED_WI_ASSEMBLY = "WixToolset.Dtf.WindowsInstaller.dll";

        /// <summary>
        /// Full type name of the attribute that marks a public static method as a custom action.
        /// </summary>
        private const string CUSTOM_ACTION_ATTRIBUTE = "WixToolset.Dtf.WindowsInstaller.CustomActionAttribute";

        /// <summary>
        /// Full type name of the interface that an embedded UI class implements.
        /// </summary>
        private const string EMBEDDED_UI_INTERFACE = "WixToolset.Dtf.WindowsInstaller.IEmbeddedUI";

        /// <summary>
        /// Destination for progress and diagnostic output. Assigned by <see cref="Build"/>;
        /// null until then.
        /// </summary>
        private static TextWriter log;

        /// <summary>
        /// Prints usage text for the tool.
        /// </summary>
        /// <param name="w">Console text writer.</param>
        private static void Usage(TextWriter w)
        {
            w.WriteLine("WiX Toolset custom action packager version {0}", Assembly.GetExecutingAssembly().GetName().Version);
            w.WriteLine("Copyright (C) .NET Foundation and contributors. All rights reserved.");
            w.WriteLine();
            // The positional order here must match the parsing in Main:
            // output package, SfxCA.dll template, then the CA assembly followed by support files.
            w.WriteLine("Usage: WixToolset.Dtf.MakeSfxCA [-v] <outputca.dll> SfxCA.dll <inputca.dll> [support files ...]");
            w.WriteLine();
            w.WriteLine("Makes a self-extracting managed MSI CA or UI DLL package.");
            w.WriteLine("Support files must include " + MakeSfxCA.REQUIRED_WI_ASSEMBLY);
            w.WriteLine("Support files optionally include CustomAction.config/EmbeddedUI.config");
            w.WriteLine("Arguments may also be read from a response file with @<file>, one argument per line.");
        }

        /// <summary>
        /// Runs the MakeSfxCA command-line tool.
        /// </summary>
        /// <param name="args">Command-line arguments. Arguments of the form "@file" are
        /// replaced by the lines of that response file.</param>
        /// <returns>0 on success, nonzero on failure.</returns>
        public static int Main(string[] args)
        {
            try
            {
                var logger = TextWriter.Null;
                var output = String.Empty;
                var sfxDll = String.Empty;
                var inputs = new List<string>();

                // Response-file expansion happens inside the try so that an unreadable
                // "@file" is reported like any other error instead of crashing the tool.
                foreach (var arg in ExpandArguments(args))
                {
                    if (arg == "-v")
                    {
                        // "-v" may appear anywhere and enables verbose logging to stdout.
                        logger = Console.Out;
                    }
                    else if (String.IsNullOrEmpty(output))
                    {
                        output = arg;
                    }
                    else if (String.IsNullOrEmpty(sfxDll))
                    {
                        sfxDll = arg;
                    }
                    else
                    {
                        inputs.Add(arg);
                    }
                }

                if (inputs.Count == 0)
                {
                    Usage(Console.Out);
                    return 1;
                }

                Build(output, sfxDll, inputs, logger);
                return 0;
            }
            catch (ArgumentException ex)
            {
                Console.Error.WriteLine("Error: Invalid argument: " + ex.Message);
                return 1;
            }
            catch (FileNotFoundException ex)
            {
                Console.Error.WriteLine("Error: Cannot find file: " + ex.Message);
                return 1;
            }
            catch (Exception ex)
            {
                Console.Error.WriteLine("Error: Unexpected error: " + ex);
                return 1;
            }
        }

        /// <summary>
        /// Reads the arguments, expanding any response files.
        /// </summary>
        /// <param name="args">Arguments to expand. An argument starting with '@' names a
        /// response file whose non-blank lines become arguments; surrounding whitespace and
        /// double quotes are trimmed from each line.</param>
        /// <returns>Expanded list of arguments, with blank arguments removed.</returns>
        private static List<string> ExpandArguments(string[] args)
        {
            var result = new List<string>(args.Length);
            foreach (var arg in args)
            {
                if (String.IsNullOrWhiteSpace(arg))
                {
                    continue;
                }

                if (arg.StartsWith("@", StringComparison.Ordinal))
                {
                    var lines = File.ReadAllLines(arg.Substring(1));

                    // Trim whitespace before quotes so that a quoted line followed by stray
                    // spaces still has both quotes removed.
                    result.AddRange(lines
                        .Select(line => line.Trim().Trim('"'))
                        .Where(line => !String.IsNullOrWhiteSpace(line)));
                }
                else
                {
                    result.Add(arg);
                }
            }

            return result;
        }

        /// <summary>
        /// Packages up all the inputs to the output location.
        /// </summary>
        /// <param name="output">Path of the self-extracting package to create.</param>
        /// <param name="sfxDll">Path of the SfxCA.dll template.</param>
        /// <param name="inputs">The CA assembly followed by support files. Items may be
        /// semicolon-delimited lists, and each file may use the form "nameInCab=pathOnDisk".</param>
        /// <param name="log">Writer that receives progress output.</param>
        /// <exception cref="ArgumentException">A required argument is missing, the required
        /// support assembly is not included, or no CA/UI entry points were found.</exception>
        /// <exception cref="FileNotFoundException">SfxCA.dll or the CA assembly does not exist.</exception>
        /// <exception cref="NotSupportedException">The assembly contains both CA and UI entry points.</exception>
        private static void Build(string output, string sfxDll, IList<string> inputs, TextWriter log)
        {
            MakeSfxCA.log = log;

            if (String.IsNullOrEmpty(output))
            {
                throw new ArgumentNullException("output");
            }

            if (String.IsNullOrEmpty(sfxDll))
            {
                throw new ArgumentNullException("sfxDll");
            }

            if (inputs == null || inputs.Count == 0)
            {
                throw new ArgumentNullException("inputs");
            }

            // Split semicolon-delimited items first so the CA assembly is the first real
            // file even when the inputs arrive as a single list argument.
            inputs = MakeSfxCA.SplitList(inputs);
            if (inputs.Count == 0)
            {
                throw new ArgumentNullException("inputs");
            }

            if (!File.Exists(sfxDll))
            {
                throw new FileNotFoundException(sfxDll);
            }

            var customActionAssembly = inputs[0];
            if (!File.Exists(customActionAssembly))
            {
                throw new FileNotFoundException(customActionAssembly);
            }

            var inputsMap = MakeSfxCA.GetPackFileMap(inputs);

            var foundWIAssembly = inputsMap.Keys.Any(name =>
                String.Equals(name, MakeSfxCA.REQUIRED_WI_ASSEMBLY, StringComparison.OrdinalIgnoreCase));
            if (!foundWIAssembly)
            {
                throw new ArgumentException(MakeSfxCA.REQUIRED_WI_ASSEMBLY +
                    " must be included in the list of support files. " +
                    "If using the MSBuild targets, make sure the assembly reference " +
                    "has the Private (Copy Local) flag set.");
            }

            MakeSfxCA.ResolveDependentAssemblies(inputsMap, Path.GetDirectoryName(customActionAssembly));

            var entryPoints = MakeSfxCA.FindEntryPoints(customActionAssembly);
            var uiClass = MakeSfxCA.FindEmbeddedUIClass(customActionAssembly);

            if (entryPoints.Count == 0 && uiClass == null)
            {
                throw new ArgumentException(
                    "No CA or UI entry points found in module: " + customActionAssembly);
            }
            else if (entryPoints.Count > 0 && uiClass != null)
            {
                throw new NotSupportedException(
                    "CA and UI entry points cannot be in the same assembly: " + customActionAssembly);
            }

            // GetDirectoryName returns "" for a bare file name and null for a root path.
            var dir = Path.GetDirectoryName(output);
            if (!String.IsNullOrEmpty(dir) && !Directory.Exists(dir))
            {
                Directory.CreateDirectory(dir);
            }

            using (Stream outputStream = File.Create(output))
            {
                MakeSfxCA.WriteEntryModule(sfxDll, outputStream, entryPoints, uiClass);
            }

            MakeSfxCA.CopyVersionResource(customActionAssembly, output);

            MakeSfxCA.PackInputFiles(output, inputsMap);

            log.WriteLine("MakeSfxCA finished: " + new FileInfo(output).FullName);
        }

        /// <summary>
        /// Splits any list items delimited by semicolons into separate items.
        /// </summary>
        /// <param name="list">Read-only input list.</param>
        /// <returns>New list with resulting split items; null, empty and blank-between-semicolon
        /// items are dropped.</returns>
        private static IList<string> SplitList(IList<string> list)
        {
            var newList = new List<string>(list.Count);
            foreach (var item in list)
            {
                if (String.IsNullOrEmpty(item))
                {
                    continue;
                }

                newList.AddRange(item.Split(new char[] { ';' }, StringSplitOptions.RemoveEmptyEntries));
            }

            return newList;
        }

        /// <summary>
        /// Sets up a reflection-only assembly-resolve-handler to handle loading dependent assemblies during reflection.
        /// </summary>
        /// <param name="inputFiles">Map of cab names to input files, which include non-GAC dependent assemblies.</param>
        /// <param name="inputDir">Directory to auto-locate additional dependent assemblies.</param>
        /// <remarks>
        /// Also searches the assembly's directory for unspecified dependent assemblies, and adds them
        /// to the list of input files if found. Resolution order is: input files, then
        /// <paramref name="inputDir"/>, then the GAC.
        /// </remarks>
        private static void ResolveDependentAssemblies(IDictionary<string, string> inputFiles, string inputDir)
        {
            AppDomain.CurrentDomain.ReflectionOnlyAssemblyResolve += delegate (object sender, ResolveEventArgs args)
            {
                var resolveName = new AssemblyName(args.Name);
                Assembly assembly = null;

                // First, try to find the assembly in the list of input files.
                foreach (var inputFile in inputFiles.Values)
                {
                    var inputName = Path.GetFileNameWithoutExtension(inputFile);
                    var inputExtension = Path.GetExtension(inputFile);
                    if (String.Equals(inputName, resolveName.Name, StringComparison.OrdinalIgnoreCase) &&
                        (String.Equals(inputExtension, ".dll", StringComparison.OrdinalIgnoreCase) ||
                         String.Equals(inputExtension, ".exe", StringComparison.OrdinalIgnoreCase)))
                    {
                        assembly = MakeSfxCA.TryLoadDependentAssembly(inputFile);

                        if (assembly != null)
                        {
                            break;
                        }
                    }
                }

                // Second, try to find the assembly in the input directory (.dll preferred over .exe).
                if (assembly == null && inputDir != null)
                {
                    string assemblyPath = null;
                    foreach (var extension in new string[] { ".dll", ".exe" })
                    {
                        var candidate = Path.Combine(inputDir, resolveName.Name) + extension;
                        if (File.Exists(candidate))
                        {
                            assemblyPath = candidate;
                            break;
                        }
                    }

                    if (assemblyPath != null)
                    {
                        assembly = MakeSfxCA.TryLoadDependentAssembly(assemblyPath);

                        if (assembly != null)
                        {
                            // Add this detected dependency to the list of files to be packed, unless
                            // an input already uses that cab name (Add would otherwise throw from
                            // inside this resolve handler).
                            var cabName = Path.GetFileName(assemblyPath);
                            if (!inputFiles.ContainsKey(cabName))
                            {
                                inputFiles.Add(cabName, assemblyPath);
                            }
                        }
                    }
                }

                // Third, try to load the assembly from the GAC.
                if (assembly == null)
                {
                    try
                    {
                        assembly = Assembly.ReflectionOnlyLoad(args.Name);
                    }
                    catch (FileNotFoundException)
                    {
                        // Not found; reported as "not supplied" below.
                    }
                    catch (FileLoadException)
                    {
                        // Could not be loaded; reported as "not supplied" below.
                    }
                    catch (BadImageFormatException)
                    {
                        // Not a valid assembly; reported as "not supplied" below.
                    }
                }

                if (assembly != null)
                {
                    if (String.Equals(assembly.GetName().ToString(), resolveName.ToString()))
                    {
                        log.WriteLine("    Loaded dependent assembly: " + assembly.Location);
                        return assembly;
                    }

                    log.WriteLine("    Warning: Loaded mismatched dependent assembly: " + assembly.Location);
                    log.WriteLine("      Loaded assembly   : " + assembly.GetName());
                    log.WriteLine("      Reference assembly: " + resolveName);
                }
                else
                {
                    log.WriteLine("    Error: Dependent assembly not supplied: " + resolveName);
                }

                return null;
            };
        }

        /// <summary>
        /// Attempts a reflection-only load of a dependent assembly, logging the error if the load fails.
        /// </summary>
        /// <param name="assemblyPath">Path of the assembly file to load.</param>
        /// <returns>Loaded assembly, or null if the load failed.</returns>
        private static Assembly TryLoadDependentAssembly(string assemblyPath)
        {
            Assembly assembly = null;
            try
            {
                assembly = Assembly.ReflectionOnlyLoadFrom(assemblyPath);
            }
            catch (IOException ex)
            {
                log.WriteLine("    Error: Failed to load dependent assembly: {0}. {1}", assemblyPath, ex.Message);
            }
            catch (BadImageFormatException ex)
            {
                log.WriteLine("    Error: Failed to load dependent assembly: {0}. {1}", assemblyPath, ex.Message);
            }
            catch (SecurityException ex)
            {
                log.WriteLine("    Error: Failed to load dependent assembly: {0}. {1}", assemblyPath, ex.Message);
            }

            return assembly;
        }

        /// <summary>
        /// Searches the types in the input assembly for a type that implements IEmbeddedUI.
        /// </summary>
        /// <param name="module">Path of the assembly to search.</param>
        /// <returns>The "assemblyName!fullClassName" of the single non-abstract exported
        /// implementation, or null if there is none.</returns>
        /// <exception cref="ArgumentException">More than one implementation was found.</exception>
        private static string FindEmbeddedUIClass(string module)
        {
            log.WriteLine("Searching for an embedded UI class in {0}", Path.GetFileName(module));

            string uiClass = null;

            var assembly = Assembly.ReflectionOnlyLoadFrom(module);

            foreach (var type in assembly.GetExportedTypes())
            {
                if (type.IsAbstract)
                {
                    continue;
                }

                foreach (var interfaceType in type.GetInterfaces())
                {
                    if (interfaceType.FullName != MakeSfxCA.EMBEDDED_UI_INTERFACE)
                    {
                        continue;
                    }

                    if (uiClass != null)
                    {
                        throw new ArgumentException("Multiple IEmbeddedUI implementations found.");
                    }

                    uiClass = assembly.GetName().Name + "!" + type.FullName;
                }
            }

            return uiClass;
        }

        /// <summary>
        /// Reflects on an input CA module to locate custom action entry-points.
        /// </summary>
        /// <param name="module">Assembly module with CA entry-points.</param>
        /// <returns>Mapping from entry-point names to assembly!class.method paths.</returns>
        /// <exception cref="ArgumentException">Two methods declare the same entry-point name.</exception>
        private static IDictionary<string, string> FindEntryPoints(string module)
        {
            log.WriteLine("Searching for custom action entry points " +
                "in {0}", Path.GetFileName(module));

            var entryPoints = new Dictionary<string, string>();

            var assembly = Assembly.ReflectionOnlyLoadFrom(module);

            foreach (var type in assembly.GetExportedTypes())
            {
                foreach (var method in type.GetMethods(BindingFlags.Public | BindingFlags.Static))
                {
                    var entryPointName = MakeSfxCA.GetEntryPoint(method);
                    if (entryPointName == null)
                    {
                        continue;
                    }

                    var entryPointPath = String.Format(
                        "{0}!{1}.{2}",
                        Path.GetFileNameWithoutExtension(module),
                        type.FullName,
                        method.Name);

                    // Report the conflicting name instead of the dictionary's generic duplicate-key message.
                    if (entryPoints.ContainsKey(entryPointName))
                    {
                        throw new ArgumentException(String.Format(
                            "Duplicate custom action entry point name {0}: {1} and {2}",
                            entryPointName,
                            entryPoints[entryPointName],
                            entryPointPath));
                    }

                    entryPoints.Add(entryPointName, entryPointPath);

                    log.WriteLine("    {0}={1}", entryPointName, entryPointPath);
                }
            }

            return entryPoints;
        }

        /// <summary>
        /// Check for a CustomActionAttribute and return the entrypoint name for the method if it is a CA method.
        /// </summary>
        /// <param name="method">A public static method.</param>
        /// <returns>Entrypoint name for the method as specified by the custom action attribute or just the method name,
        /// or null if the method is not a custom action method.</returns>
        private static string GetEntryPoint(MethodInfo method)
        {
            IList<CustomAttributeData> attributes;
            try
            {
                attributes = CustomAttributeData.GetCustomAttributes(method);
            }
            catch (FileLoadException)
            {
                // Already logged load failures in the assembly-resolve-handler.
                return null;
            }

            foreach (var attribute in attributes)
            {
                if (!String.Equals(
                    attribute.Constructor.DeclaringType.FullName,
                    MakeSfxCA.CUSTOM_ACTION_ATTRIBUTE,
                    StringComparison.Ordinal))
                {
                    continue;
                }

                // The entry point name is the first positional argument, if specified.
                string entryPointName = null;
                if (attribute.ConstructorArguments.Count > 0)
                {
                    entryPointName = attribute.ConstructorArguments[0].Value as string;
                }

                return String.IsNullOrEmpty(entryPointName) ? method.Name : entryPointName;
            }

            return null;
        }

        /// <summary>
        /// Counts the number of template entrypoints in SfxCA.dll.
        /// </summary>
        /// <param name="fileBytes">Contents of SfxCA.dll.</param>
        /// <param name="entryPointFormat">Format string producing the template name for a slot index.</param>
        /// <returns>Number of consecutive slots, starting at 0, whose ASCII template name is present.</returns>
        /// <remarks>
        /// Depending on the requirements, SfxCA.dll might be built with
        /// more entrypoints than the default.
        /// </remarks>
        private static int GetEntryPointSlotCount(byte[] fileBytes, string entryPointFormat)
        {
            for (var count = 0; ; count++)
            {
                var templateName = String.Format(entryPointFormat, count);
                var templateAsciiBytes = Encoding.ASCII.GetBytes(templateName);

                var nameOffset = FindBytes(fileBytes, templateAsciiBytes);
                if (nameOffset < 0)
                {
                    return count;
                }
            }
        }

        /// <summary>
        /// Writes a modified version of SfxCA.dll to the output stream,
        /// with the template entry-points mapped to the CA entry-points.
        /// </summary>
        /// <param name="sfxDll">Path of the SfxCA.dll template.</param>
        /// <param name="outputStream">Stream that receives the modified module.</param>
        /// <param name="entryPoints">Mapping from entry-point names to assembly!class.method paths.</param>
        /// <param name="uiClass">Embedded UI class (assembly!class), or null.</param>
        /// <exception cref="ArgumentException">The template is invalid or a name/path exceeds its slot.</exception>
        /// <remarks>
        /// To avoid having to recompile SfxCA.dll for every different set of CAs,
        /// this method looks for a preset number of template entry-points in the
        /// binary file and overwrites their entrypoint name and string data with
        /// CA-specific values.
        /// </remarks>
        private static void WriteEntryModule(
            string sfxDll, Stream outputStream, IDictionary<string, string> entryPoints, string uiClass)
        {
            log.WriteLine("Modifying SfxCA.dll stub");

            // ReadAllBytes reads the whole file; a single Stream.Read call may return fewer bytes.
            var fileBytes = File.ReadAllBytes(sfxDll);

            const string ENTRYPOINT_FORMAT = "CustomActionEntryPoint{0:d03}";
            const int MAX_ENTRYPOINT_NAME = 72;
            const int MAX_ENTRYPOINT_PATH = 160;

            var slotCount = MakeSfxCA.GetEntryPointSlotCount(fileBytes, ENTRYPOINT_FORMAT);

            if (slotCount == 0)
            {
                throw new ArgumentException("Invalid SfxCA.dll file.");
            }

            if (entryPoints.Count > slotCount)
            {
                throw new ArgumentException(String.Format(
                    "The custom action assembly has {0} entrypoints, which is more than the maximum ({1}). " +
                    "Refactor the custom actions or add more entrypoint slots in SfxCA\\EntryPoints.h.",
                    entryPoints.Count, slotCount));
            }

            // Unused slots come first (empty names); the real entry points fill the
            // remaining slots in ordinal-sorted order.
            var firstUsedSlot = slotCount - entryPoints.Count;
            var slotSort = new string[slotCount];
            for (var i = 0; i < firstUsedSlot; i++)
            {
                slotSort[i] = String.Empty;
            }

            entryPoints.Keys.CopyTo(slotSort, firstUsedSlot);
            Array.Sort(slotSort, firstUsedSlot, entryPoints.Count, StringComparer.Ordinal);

            for (var i = 0; i < slotCount; i++)
            {
                var templateName = String.Format(ENTRYPOINT_FORMAT, i);
                var templateAsciiBytes = Encoding.ASCII.GetBytes(templateName);
                var templateUniBytes = Encoding.Unicode.GetBytes(templateName);

                var nameOffset = MakeSfxCA.FindBytes(fileBytes, templateAsciiBytes);
                var pathOffset = MakeSfxCA.FindBytes(fileBytes, templateUniBytes);
                if (nameOffset < 0 || pathOffset < 0)
                {
                    // Stopping early would silently drop the entry points assigned to later slots.
                    throw new ArgumentException("Invalid SfxCA.dll file: missing entry point template " + templateName);
                }

                var entryPointName = slotSort[i];
                var entryPointPath = entryPointName.Length > 0 ?
                    entryPoints[entryPointName] : String.Empty;

                if (entryPointName.Length > MAX_ENTRYPOINT_NAME)
                {
                    throw new ArgumentException(String.Format(
                        "Entry point name exceeds limit of {0} characters: {1}",
                        MAX_ENTRYPOINT_NAME,
                        entryPointName));
                }

                if (entryPointPath.Length > MAX_ENTRYPOINT_PATH)
                {
                    throw new ArgumentException(String.Format(
                        "Entry point path exceeds limit of {0} characters: {1}",
                        MAX_ENTRYPOINT_PATH,
                        entryPointPath));
                }

                var replaceNameBytes = Encoding.ASCII.GetBytes(entryPointName);
                var replacePathBytes = Encoding.Unicode.GetBytes(entryPointPath);

                MakeSfxCA.ReplaceBytes(fileBytes, nameOffset, MAX_ENTRYPOINT_NAME, replaceNameBytes);
                // The path is UTF-16, so its slot is two bytes per character.
                MakeSfxCA.ReplaceBytes(fileBytes, pathOffset, MAX_ENTRYPOINT_PATH * 2, replacePathBytes);
            }

            if (entryPoints.Count == 0 && uiClass != null)
            {
                // Validate before modifying any bytes.
                if (uiClass.Length > MAX_ENTRYPOINT_PATH)
                {
                    throw new ArgumentException(String.Format(
                        "UI class full name exceeds limit of {0} characters: {1}",
                        MAX_ENTRYPOINT_PATH,
                        uiClass));
                }

                // Remove the zzz prefix from exported EmbeddedUI entry-points.
                foreach (var export in new string[] { "InitializeEmbeddedUI", "EmbeddedUIHandler", "ShutdownEmbeddedUI" })
                {
                    var exportNameBytes = Encoding.ASCII.GetBytes("zzz" + export);

                    var exportOffset = MakeSfxCA.FindBytes(fileBytes, exportNameBytes);
                    if (exportOffset < 0)
                    {
                        throw new ArgumentException("Input SfxCA.dll does not contain exported entry-point: " + export);
                    }

                    var replaceNameBytes = Encoding.ASCII.GetBytes(export);
                    MakeSfxCA.ReplaceBytes(fileBytes, exportOffset, exportNameBytes.Length, replaceNameBytes);
                }

                // Fill in the embedded UI implementor class so the proxy knows which one to load.
                var templateBytes = Encoding.Unicode.GetBytes("InitializeEmbeddedUI_FullClassName");
                var replaceOffset = MakeSfxCA.FindBytes(fileBytes, templateBytes);
                if (replaceOffset < 0)
                {
                    throw new ArgumentException("Input SfxCA.dll does not contain the embedded UI class name template.");
                }

                var replaceBytes = Encoding.Unicode.GetBytes(uiClass);
                MakeSfxCA.ReplaceBytes(fileBytes, replaceOffset, MAX_ENTRYPOINT_PATH * 2, replaceBytes);
            }

            outputStream.Write(fileBytes, 0, fileBytes.Length);
        }

        /// <summary>
        /// Searches for a sub-array of bytes within a larger array of bytes.
        /// </summary>
        /// <param name="source">Array to search.</param>
        /// <param name="find">Byte sequence to look for.</param>
        /// <returns>Offset of the first occurrence, or -1 if not found.</returns>
        private static int FindBytes(byte[] source, byte[] find)
        {
            // Only start positions that leave room for the whole pattern are valid;
            // looping to source.Length would index past the end on a partial match.
            var lastStart = source.Length - find.Length;
            for (var i = 0; i <= lastStart; i++)
            {
                var j = 0;
                while (j < find.Length && source[i + j] == find[j])
                {
                    j++;
                }

                if (j == find.Length)
                {
                    return i;
                }
            }

            return -1;
        }

        /// <summary>
        /// Replaces a range of bytes with new bytes, padding any extra part
        /// of the range with zeroes.
        /// </summary>
        /// <param name="source">Array to modify in place.</param>
        /// <param name="offset">Start of the range.</param>
        /// <param name="length">Length of the range; replacement bytes beyond it are ignored.</param>
        /// <param name="replace">Replacement bytes.</param>
        private static void ReplaceBytes(
            byte[] source, int offset, int length, byte[] replace)
        {
            for (var i = 0; i < length; i++)
            {
                source[offset + i] = i < replace.Length ? replace[i] : (byte)0;
            }
        }

        /// <summary>
        /// Print the name of one file as it is being packed into the cab.
        /// </summary>
        private static void PackProgress(object source, ArchiveProgressEventArgs e)
        {
            if (e.ProgressType == ArchiveProgressType.StartFile && log != null)
            {
                log.WriteLine("    {0}", e.CurrentFileName);
            }
        }

        /// <summary>
        /// Gets a mapping from filenames as they will be in the cab to filenames
        /// as they are currently on disk.
        /// </summary>
        /// <param name="inputs">Input file specifications.</param>
        /// <returns>Case-insensitive map from cab name to source path; the first input
        /// for a given cab name wins.</returns>
        /// <remarks>
        /// By default, all files will be placed in the root of the cab. But inputs may
        /// optionally include an alternate inside-cab file path before an equals sign.
        /// </remarks>
        private static IDictionary<string, string> GetPackFileMap(IList<string> inputs)
        {
            // Cab (and Windows) file names are case-insensitive, so names differing only by case collide.
            var fileMap = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
            foreach (var inputFile in inputs)
            {
                string cabName;
                string sourcePath;

                // Split at the first '=' only, so the on-disk path may itself contain '='.
                var separator = inputFile.IndexOf('=');
                if (separator > 0)
                {
                    cabName = inputFile.Substring(0, separator);
                    sourcePath = inputFile.Substring(separator + 1);
                }
                else
                {
                    cabName = Path.GetFileName(inputFile);
                    sourcePath = inputFile;
                }

                if (!fileMap.ContainsKey(cabName))
                {
                    fileMap.Add(cabName, sourcePath);
                }
            }

            return fileMap;
        }

        /// <summary>
        /// Packs the input files into a cab that is appended to the
        /// output SfxCA.dll.
        /// </summary>
        /// <param name="outputFile">Package file that the cab is written to.</param>
        /// <param name="fileMap">Map from cab names to source paths.</param>
        private static void PackInputFiles(string outputFile, IDictionary<string, string> fileMap)
        {
            log.WriteLine("Packaging files");

            var cabInfo = new CabInfo(outputFile);
            cabInfo.PackFileSet(null, fileMap, CompressionLevel.Max, PackProgress);
        }

        /// <summary>
        /// Copies the version resource information from the CA module to
        /// the CA package. This gives the package the file version and
        /// description of the CA module, instead of the version and
        /// description of SfxCA.dll.
        /// </summary>
        /// <param name="sourceFile">CA module to copy the version resource from.</param>
        /// <param name="destFile">Package to copy the version resource to.</param>
        private static void CopyVersionResource(string sourceFile, string destFile)
        {
            log.WriteLine("Copying file version info from {0} to {1}",
                sourceFile, destFile);

            var rc = new ResourceCollection();
            rc.Find(sourceFile, ResourceType.Version);
            rc.Load(sourceFile);
            rc.Save(destFile);
        }
    }
}