// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information.
namespace WixToolset.Core.ExtensionCache
{
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using NuGet.Common;
using NuGet.Configuration;
using NuGet.Credentials;
using NuGet.Packaging;
using NuGet.Protocol;
using NuGet.Protocol.Core.Types;
using NuGet.Versioning;
///
/// Extension cache manager.
///
internal class ExtensionCacheManager
{
public string CacheFolder(bool global) => global ? this.GlobalCacheFolder() : this.LocalCacheFolder();
public string LocalCacheFolder() => Path.Combine(Environment.CurrentDirectory, ".wix", "extensions");
public string GlobalCacheFolder()
{
var baseFolder = Environment.GetEnvironmentVariable("WIX_EXTENSIONS") ?? Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
return Path.Combine(baseFolder, ".wix", "extensions");
}
public async Task AddAsync(bool global, string extension, CancellationToken cancellationToken)
{
if (String.IsNullOrEmpty(extension))
{
throw new ArgumentNullException(nameof(extension));
}
(var extensionId, var extensionVersion) = ParseExtensionReference(extension);
var result = await this.DownloadAndExtractAsync(global, extensionId, extensionVersion, cancellationToken);
return result;
}
public Task RemoveAsync(bool global, string extension, CancellationToken cancellationToken)
{
if (String.IsNullOrEmpty(extension))
{
throw new ArgumentNullException(nameof(extension));
}
(var extensionId, var extensionVersion) = ParseExtensionReference(extension);
var cacheFolder = this.CacheFolder(global);
cacheFolder = Path.Combine(cacheFolder, extensionId, extensionVersion);
if (Directory.Exists(cacheFolder))
{
cancellationToken.ThrowIfCancellationRequested();
Directory.Delete(cacheFolder, true);
return Task.FromResult(true);
}
return Task.FromResult(false);
}
public Task> ListAsync(bool global, string extension, CancellationToken cancellationToken)
{
var found = new List();
(var extensionId, var extensionVersion) = ParseExtensionReference(extension);
var cacheFolder = this.CacheFolder(global);
var searchFolder = Path.Combine(cacheFolder, extensionId, extensionVersion);
if (!Directory.Exists(searchFolder))
{
}
else if (!String.IsNullOrEmpty(extensionVersion)) // looking for an explicit version of an extension.
{
var present = ExtensionFileExists(cacheFolder, extensionId, extensionVersion);
found.Add(new CachedExtension(extensionId, extensionVersion, !present));
}
else // looking for all versions of an extension or all versions of all extensions.
{
IEnumerable foundExtensionIds;
if (String.IsNullOrEmpty(extensionId))
{
// Looking for all versions of all extensions.
foundExtensionIds = Directory.GetDirectories(cacheFolder).Select(folder => Path.GetFileName(folder)).ToList();
}
else
{
// Looking for all versions of a single extension.
var extensionFolder = Path.Combine(cacheFolder, extensionId);
foundExtensionIds = Directory.Exists(extensionFolder) ? new[] { extensionId } : Array.Empty();
}
foreach (var foundExtensionId in foundExtensionIds)
{
var extensionFolder = Path.Combine(cacheFolder, foundExtensionId);
foreach (var folder in Directory.GetDirectories(extensionFolder))
{
cancellationToken.ThrowIfCancellationRequested();
var foundExtensionVersion = Path.GetFileName(folder);
if (!NuGetVersion.TryParse(foundExtensionVersion, out _))
{
continue;
}
var present = ExtensionFileExists(cacheFolder, foundExtensionId, foundExtensionVersion);
found.Add(new CachedExtension(foundExtensionId, foundExtensionVersion, !present));
}
}
}
return Task.FromResult((IEnumerable)found);
}
private async Task DownloadAndExtractAsync(bool global, string id, string version, CancellationToken cancellationToken)
{
var logger = NullLogger.Instance;
DefaultCredentialServiceUtility.SetupDefaultCredentialService(logger, nonInteractive: false);
var settings = Settings.LoadDefaultSettings(root: Environment.CurrentDirectory);
var sources = PackageSourceProvider.LoadPackageSources(settings).Where(s => s.IsEnabled);
using (var cache = new SourceCacheContext())
{
PackageSource versionSource = null;
var nugetVersion = String.IsNullOrEmpty(version) ? null : new NuGetVersion(version);
if (nugetVersion is null)
{
foreach (var source in sources)
{
var repository = Repository.Factory.GetCoreV3(source.Source);
var resource = await repository.GetResourceAsync();
var availableVersions = await resource.GetAllVersionsAsync(id, cache, logger, cancellationToken);
foreach (var availableVersion in availableVersions)
{
if (nugetVersion is null || nugetVersion < availableVersion)
{
nugetVersion = availableVersion;
versionSource = source;
}
}
}
if (nugetVersion is null)
{
return false;
}
}
var searchSources = versionSource is null ? sources : new[] { versionSource };
var extensionFolder = Path.Combine(this.CacheFolder(global), id, nugetVersion.ToString());
foreach (var source in searchSources)
{
var repository = Repository.Factory.GetCoreV3(source.Source);
var resource = await repository.GetResourceAsync();
using (var stream = new MemoryStream())
{
var downloaded = await resource.CopyNupkgToStreamAsync(id, nugetVersion, stream, cache, logger, cancellationToken);
if (downloaded)
{
stream.Position = 0;
using (var archive = new PackageArchiveReader(stream))
{
var files = PackagingConstants.Folders.Known.SelectMany(folder => archive.GetFiles(folder)).Distinct(StringComparer.OrdinalIgnoreCase);
await archive.CopyFilesAsync(extensionFolder, files, this.ExtractProgress, logger, cancellationToken);
}
return true;
}
}
}
}
return false;
}
private string ExtractProgress(string sourceFile, string targetPath, Stream fileStream) => fileStream.CopyToFile(targetPath);
private static (string extensionId, string extensionVersion) ParseExtensionReference(string extensionReference)
{
var extensionId = extensionReference ?? String.Empty;
var extensionVersion = String.Empty;
var index = extensionId.LastIndexOf('/');
if (index > 0)
{
extensionVersion = extensionReference.Substring(index + 1);
extensionId = extensionReference.Substring(0, index);
if (!NuGetVersion.TryParse(extensionVersion, out _))
{
throw new ArgumentException($"Invalid extension version in {extensionReference}");
}
if (String.IsNullOrEmpty(extensionId))
{
throw new ArgumentException($"Invalid extension id in {extensionReference}");
}
}
return (extensionId, extensionVersion);
}
private static bool ExtensionFileExists(string baseFolder, string extensionId, string extensionVersion)
{
var toolsFolder = Path.Combine(baseFolder, extensionId, extensionVersion, "tools");
if (!Directory.Exists(toolsFolder))
{
return false;
}
var extensionAssembly = Path.Combine(toolsFolder, extensionId + ".dll");
var present = File.Exists(extensionAssembly);
if (!present)
{
extensionAssembly = Path.Combine(toolsFolder, extensionId + ".exe");
present = File.Exists(extensionAssembly);
}
return present;
}
}
}