// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information. namespace WixToolset.Firewall { using System; using System.Collections.Generic; using System.Xml.Linq; using WixToolset.Data; using WixToolset.Extensibility; using WixToolset.Extensibility.Data; using WixToolset.Firewall.Symbols; /// /// The compiler for the WiX Toolset Firewall Extension. /// public sealed class FirewallCompiler : BaseCompilerExtension { public override XNamespace Namespace => "http://wixtoolset.org/schemas/v4/wxs/firewall"; /// /// Processes an element for the Compiler. /// /// Source line number for the parent element. /// Parent element of element to process. /// Element to process. /// Extra information about the context in which this element is being parsed. public override void ParseElement(Intermediate intermediate, IntermediateSection section, XElement parentElement, XElement element, IDictionary context) { switch (parentElement.Name.LocalName) { case "File": var fileId = context["FileId"]; var fileComponentId = context["ComponentId"]; switch (element.Name.LocalName) { case "FirewallException": this.ParseFirewallExceptionElement(intermediate, section, element, fileComponentId, fileId); break; default: this.ParseHelper.UnexpectedElement(parentElement, element); break; } break; case "Component": var componentId = context["ComponentId"]; switch (element.Name.LocalName) { case "FirewallException": this.ParseFirewallExceptionElement(intermediate, section, element, componentId, null); break; default: this.ParseHelper.UnexpectedElement(parentElement, element); break; } break; default: this.ParseHelper.UnexpectedElement(parentElement, element); break; } } /// /// Parses a FirewallException element. /// /// The element to parse. /// Identifier of the component that owns this firewall exception. /// The file identifier of the parent element (null if nested under Component). private void ParseFirewallExceptionElement(Intermediate intermediate, IntermediateSection section, XElement element, string componentId, string fileId) { var sourceLineNumbers = this.ParseHelper.GetSourceLineNumbers(element); Identifier id = null; string name = null; int attributes = 0; string file = null; string program = null; string port = null; int? protocol = null; int? profile = null; string scope = null; string remoteAddresses = null; string description = null; int? direction = null; foreach (var attrib in element.Attributes()) { if (String.IsNullOrEmpty(attrib.Name.NamespaceName) || this.Namespace == attrib.Name.Namespace) { switch (attrib.Name.LocalName) { case "Id": id = this.ParseHelper.GetAttributeIdentifier(sourceLineNumbers, attrib); break; case "Name": name = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); break; case "File": if (null != fileId) { this.Messaging.Write(ErrorMessages.IllegalAttributeWhenNested(sourceLineNumbers, element.Name.LocalName, "File", "File")); } else { file = this.ParseHelper.GetAttributeIdentifierValue(sourceLineNumbers, attrib); } break; case "IgnoreFailure": if (YesNoType.Yes == this.ParseHelper.GetAttributeYesNoValue(sourceLineNumbers, attrib)) { attributes |= 0x1; // feaIgnoreFailures } break; case "Program": if (null != fileId) { this.Messaging.Write(ErrorMessages.IllegalAttributeWhenNested(sourceLineNumbers, element.Name.LocalName, "Program", "File")); } else { program = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); } break; case "Port": port = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); break; case "Protocol": var protocolValue = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); switch (protocolValue) { case "tcp": protocol = FirewallConstants.NET_FW_IP_PROTOCOL_TCP; break; case "udp": protocol = FirewallConstants.NET_FW_IP_PROTOCOL_UDP; break; default: this.Messaging.Write(ErrorMessages.IllegalAttributeValue(sourceLineNumbers, element.Name.LocalName, "Protocol", protocolValue, "tcp", "udp")); break; } break; case "Scope": scope = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); switch (scope) { case "any": remoteAddresses = "*"; break; case "localSubnet": remoteAddresses = "LocalSubnet"; break; default: this.Messaging.Write(ErrorMessages.IllegalAttributeValue(sourceLineNumbers, element.Name.LocalName, "Scope", scope, "any", "localSubnet")); break; } break; case "Profile": var profileValue = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); switch (profileValue) { case "domain": profile = FirewallConstants.NET_FW_PROFILE2_DOMAIN; break; case "private": profile = FirewallConstants.NET_FW_PROFILE2_PRIVATE; break; case "public": profile = FirewallConstants.NET_FW_PROFILE2_PUBLIC; break; case "all": profile = FirewallConstants.NET_FW_PROFILE2_ALL; break; default: this.Messaging.Write(ErrorMessages.IllegalAttributeValue(sourceLineNumbers, element.Name.LocalName, "Profile", profileValue, "domain", "private", "public", "all")); break; } break; case "Description": description = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); break; case "Outbound": direction = this.ParseHelper.GetAttributeYesNoValue(sourceLineNumbers, attrib) == YesNoType.Yes ? FirewallConstants.NET_FW_RULE_DIR_OUT : FirewallConstants.NET_FW_RULE_DIR_IN; break; default: this.ParseHelper.UnexpectedAttribute(element, attrib); break; } } else { this.ParseHelper.ParseExtensionAttribute(this.Context.Extensions, intermediate, section, element, attrib); } } // parse RemoteAddress children foreach (var child in element.Elements()) { if (this.Namespace == child.Name.Namespace) { switch (child.Name.LocalName) { case "RemoteAddress": if (null != scope) { this.Messaging.Write(FirewallErrors.IllegalRemoteAddressWithScopeAttribute(sourceLineNumbers)); } else { this.ParseRemoteAddressElement(intermediate, section, child, ref remoteAddresses); } break; default: this.ParseHelper.UnexpectedElement(element, child); break; } } else { this.ParseHelper.ParseExtensionElement(this.Context.Extensions, intermediate, section, element, child); } } if (null == id) { id = this.ParseHelper.CreateIdentifier("fex", name, remoteAddresses, componentId); } // Name is required if (null == name) { this.Messaging.Write(ErrorMessages.ExpectedAttribute(sourceLineNumbers, element.Name.LocalName, "Name")); } // Scope or child RemoteAddress(es) are required if (null == remoteAddresses) { this.Messaging.Write(ErrorMessages.ExpectedAttributeOrElement(sourceLineNumbers, element.Name.LocalName, "Scope", "RemoteAddress")); } // can't have both Program and File if (null != program && null != file) { this.Messaging.Write(ErrorMessages.IllegalAttributeWithOtherAttribute(sourceLineNumbers, element.Name.LocalName, "File", "Program")); } // must be nested under File, have File or Program attributes, or have Port attribute if (String.IsNullOrEmpty(fileId) && String.IsNullOrEmpty(file) && String.IsNullOrEmpty(program) && String.IsNullOrEmpty(port)) { this.Messaging.Write(FirewallErrors.NoExceptionSpecified(sourceLineNumbers)); } if (!this.Messaging.EncounteredError) { // at this point, File attribute and File parent element are treated the same if (null != file) { fileId = file; } var symbol = section.AddSymbol(new WixFirewallExceptionSymbol(sourceLineNumbers, id) { Name = name, RemoteAddresses = remoteAddresses, Profile = profile ?? FirewallConstants.NET_FW_PROFILE2_ALL, ComponentRef = componentId, Description = description, Direction = direction ?? FirewallConstants.NET_FW_RULE_DIR_IN, }); if (!String.IsNullOrEmpty(port)) { symbol.Port = port; if (!protocol.HasValue) { // default protocol is "TCP" protocol = FirewallConstants.NET_FW_IP_PROTOCOL_TCP; } } if (protocol.HasValue) { symbol.Protocol = protocol.Value; } if (!String.IsNullOrEmpty(fileId)) { symbol.Program = $"[#{fileId}]"; this.ParseHelper.CreateSimpleReference(section, sourceLineNumbers, SymbolDefinitions.File, fileId); } else if (!String.IsNullOrEmpty(program)) { symbol.Program = program; } if (CompilerConstants.IntegerNotSet != attributes) { symbol.Attributes = attributes; } this.ParseHelper.CreateCustomActionReference(sourceLineNumbers, section, "Wix4SchedFirewallExceptionsInstall", this.Context.Platform, CustomActionPlatforms.ARM64 | CustomActionPlatforms.X64 | CustomActionPlatforms.X86); this.ParseHelper.CreateCustomActionReference(sourceLineNumbers, section, "Wix4SchedFirewallExceptionsUninstall", this.Context.Platform, CustomActionPlatforms.ARM64 | CustomActionPlatforms.X64 | CustomActionPlatforms.X86); } } /// /// Parses a RemoteAddress element /// /// The element to parse. private void ParseRemoteAddressElement(Intermediate intermediate, IntermediateSection section, XElement element, ref string remoteAddresses) { var sourceLineNumbers = this.ParseHelper.GetSourceLineNumbers(element); string address = null; // no attributes foreach (var attrib in element.Attributes()) { if (String.IsNullOrEmpty(attrib.Name.NamespaceName) || this.Namespace == attrib.Name.Namespace) { switch (attrib.Name.LocalName) { case "Value": address = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); break; } } else { this.ParseHelper.ParseExtensionAttribute(this.Context.Extensions, intermediate, section, element, attrib); } } this.ParseHelper.ParseForExtensionElements(this.Context.Extensions, intermediate, section, element); if (String.IsNullOrEmpty(address)) { this.Messaging.Write(ErrorMessages.ExpectedAttribute(sourceLineNumbers, element.Name.LocalName, "Value")); } else { if (String.IsNullOrEmpty(remoteAddresses)) { remoteAddresses = address; } else { remoteAddresses = String.Concat(remoteAddresses, ",", address); } } } } }