// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
4 
5 using System;
6 using System.Collections;
7 using System.IO;
8 using System.Runtime.InteropServices;
9 using System.Security;
10 using System.Text;
11 using System.Xml;
12 using System.Xml.XPath;
13 using System.Xml.Xsl;
14 
namespace System.Security.Cryptography.Xml
{
    /// <summary>
    /// Signature transform identified by <see cref="SignedXml.XmlDsigXsltTransformUrl"/>: applies the
    /// XSLT stylesheet supplied through <see cref="LoadInnerXml"/> to the input supplied through
    /// <see cref="LoadInput"/> and exposes the result as a <see cref="Stream"/>.
    /// </summary>
    public class XmlDsigXsltTransform : Transform
    {
        private readonly Type[] _inputTypes = { typeof(Stream), typeof(XmlDocument), typeof(XmlNodeList) };
        private readonly Type[] _outputTypes = { typeof(Stream) };
        // Forwarded to CanonicalXml when the input is an XmlNodeList or XmlDocument.
        private readonly bool _includeComments;

        // Node list exactly as passed to LoadInnerXml; handed back by GetInnerXml.
        private XmlNodeList _xslNodes;
        // Serialized (trimmed) OuterXml of the single stylesheet element found in _xslNodes.
        private string _xslFragment;
        // Current input: either the caller's stream or an in-memory canonicalized copy.
        private Stream _inputStream;

        /// <summary>
        /// Creates the transform with <c>includeComments</c> set to <see langword="false"/>.
        /// </summary>
        public XmlDsigXsltTransform()
            : this(false)
        {
        }

        /// <summary>
        /// Creates the transform.
        /// </summary>
        /// <param name="includeComments">
        /// Value passed to the canonicalizer when <see cref="LoadInput"/> receives an
        /// <see cref="XmlNodeList"/> or <see cref="XmlDocument"/>.
        /// </param>
        public XmlDsigXsltTransform(bool includeComments)
        {
            _includeComments = includeComments;
            Algorithm = SignedXml.XmlDsigXsltTransformUrl;
        }

        /// <summary>
        /// Input types accepted by <see cref="LoadInput"/>: <see cref="Stream"/>,
        /// <see cref="XmlDocument"/> and <see cref="XmlNodeList"/>.
        /// </summary>
        public override Type[] InputTypes
        {
            get
            {
                return _inputTypes;
            }
        }

        /// <summary>
        /// Output types produced by <see cref="GetOutput()"/>: only <see cref="Stream"/>.
        /// </summary>
        public override Type[] OutputTypes
        {
            get
            {
                return _outputTypes;
            }
        }

        /// <summary>
        /// Stores the XSLT stylesheet supplied as the transform's inner XML.
        /// </summary>
        /// <param name="nodeList">
        /// Nodes that must consist of exactly one element (the stylesheet), optionally accompanied by
        /// <see cref="XmlWhitespace"/> nodes.
        /// </param>
        /// <exception cref="CryptographicException">
        /// <paramref name="nodeList"/> is <see langword="null"/>, contains no element, contains more than
        /// one element, or contains any non-whitespace node besides that single element.
        /// </exception>
        public override void LoadInnerXml(XmlNodeList nodeList)
        {
            if (nodeList == null)
                throw new CryptographicException(SR.Cryptography_Xml_UnknownTransform);

            // Check that the XSLT content is well formed: exactly one element, whitespace ignored.
            XmlElement firstDataElement = null;
            int count = 0;
            foreach (XmlNode node in nodeList)
            {
                if (node is XmlWhitespace)
                    continue;

                XmlElement element = node as XmlElement;
                if (element != null)
                {
                    // Reject an element that follows another element or any other non-whitespace node.
                    if (count != 0)
                        throw new CryptographicException(SR.Cryptography_Xml_UnknownTransform);
                    firstDataElement = element;
                }

                // Every non-whitespace node is counted, so stray text, comments, etc. fail the check below.
                count++;
            }

            if (count != 1 || firstDataElement == null)
                throw new CryptographicException(SR.Cryptography_Xml_UnknownTransform);

            _xslNodes = nodeList;
            _xslFragment = firstDataElement.OuterXml.Trim();
        }

        /// <summary>
        /// Returns the node list most recently passed to <see cref="LoadInnerXml"/>, or
        /// <see langword="null"/> if none was loaded.
        /// </summary>
        protected override XmlNodeList GetInnerXml()
        {
            return _xslNodes;
        }

        /// <summary>
        /// Sets the document the stylesheet will be applied to.
        /// </summary>
        /// <param name="obj">
        /// A <see cref="Stream"/> (used as-is, not copied), or an <see cref="XmlNodeList"/> /
        /// <see cref="XmlDocument"/> (canonicalized into an in-memory stream). Any other value,
        /// including <see langword="null"/>, results in an empty input stream.
        /// </param>
        /// <remarks>
        /// Any previously loaded input stream is closed first. NOTE(review): that includes a stream a
        /// caller passed to an earlier <see cref="LoadInput"/> call.
        /// </remarks>
        public override void LoadInput(object obj)
        {
            _inputStream?.Close();

            Stream stream = obj as Stream;
            if (stream != null)
            {
                _inputStream = stream;
                return;
            }

            _inputStream = new MemoryStream();

            XmlNodeList nodeList = obj as XmlNodeList;
            if (nodeList != null)
            {
                WriteCanonicalInput(new CanonicalXml(nodeList, null, _includeComments).GetBytes());
                return;
            }

            XmlDocument document = obj as XmlDocument;
            if (document != null)
            {
                WriteCanonicalInput(new CanonicalXml(document, null, _includeComments).GetBytes());
            }
        }

        // Copies canonicalized bytes into the (fresh, empty) in-memory _inputStream and rewinds it;
        // a null buffer leaves the stream empty.
        private void WriteCanonicalInput(byte[] buffer)
        {
            if (buffer == null)
                return;
            _inputStream.Write(buffer, 0, buffer.Length);
            _inputStream.Flush();
            _inputStream.Position = 0;
        }

        /// <summary>
        /// Applies the stylesheet stored by <see cref="LoadInnerXml"/> to the input stored by
        /// <see cref="LoadInput"/>.
        /// </summary>
        /// <returns>
        /// A <see cref="MemoryStream"/> positioned at 0 that holds the transform output.
        /// </returns>
        /// <remarks>
        /// The input stream is read from its current position and is not rewound afterwards.
        /// NOTE(review): a second call for the same input therefore likely sees an empty document.
        /// </remarks>
        public override object GetOutput()
        {
            // XSLT exposes powerful features by default. The stylesheet arrives inside the signed
            // document, so it must not be able to reach beyond the documents it is given:
            //  1- XsltSettings.Default disables embedded script blocks and the document() function.
            //  2- The null stylesheet resolver passed to Load means xsl:import/xsl:include are not resolved.
            //  3- The null XmlResolver on the reader settings means external resources are not resolved,
            //     and the Utils.Max* limits bound entity expansion and overall document size.
            XmlReaderSettings settings = new XmlReaderSettings();
            settings.XmlResolver = null;
            settings.MaxCharactersFromEntities = Utils.MaxCharactersFromEntities;
            settings.MaxCharactersInDocument = Utils.MaxCharactersInDocument;

            XslCompiledTransform xslt = new XslCompiledTransform();
            using (StringReader sr = new StringReader(_xslFragment))
            using (XmlReader readerXsl = XmlReader.Create(sr, settings, (string)null))
            {
                xslt.Load(readerXsl, XsltSettings.Default, null);
            }

            // XPathDocument is used instead of XmlDocument because it is cheaper for read-only input.
            // settings.CloseInput stays at its default (false), so disposing the reader leaves
            // _inputStream open.
            XPathDocument inputData;
            using (XmlReader reader = XmlReader.Create(_inputStream, settings, BaseURI))
            {
                inputData = new XPathDocument(reader, XmlSpace.Preserve);
            }

            // XmlTextWriter with a null encoding writes UTF-8. The writer is deliberately not
            // disposed: closing it would also close the MemoryStream returned to the caller.
            MemoryStream ms = new MemoryStream();
            XmlWriter writer = new XmlTextWriter(ms, null);

            xslt.Transform(inputData, null, writer);
            // Make sure every buffered byte has reached ms before it is rewound and returned.
            writer.Flush();
            ms.Position = 0;
            return ms;
        }

        /// <summary>
        /// Returns the transform output as a <see cref="Stream"/>.
        /// </summary>
        /// <param name="type">
        /// <see cref="Stream"/> or a subclass of it. The object returned is always the
        /// <see cref="MemoryStream"/> produced by <see cref="GetOutput()"/>.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is <see langword="null"/>.</exception>
        /// <exception cref="ArgumentException"><paramref name="type"/> is not a <see cref="Stream"/> type.</exception>
        public override object GetOutput(Type type)
        {
            if (type == null)
                throw new ArgumentNullException(nameof(type));
            if (type != typeof(Stream) && !type.IsSubclassOf(typeof(Stream)))
                throw new ArgumentException(SR.Cryptography_Xml_TransformIncorrectInputType, nameof(type));
            return (Stream)GetOutput();
        }
    }
}
156