"""XML Digital Signature (XMLDSig) and XML Encryption (XMLEnc) helpers.

Ported from a PHP implementation; built on minidom, PyXML's ``xml.xpath``
and ``xml.dom.ext.c14n``, PyCrypto (symmetric ciphers, digests) and
M2Crypto (RSA).
"""
import base64
import hashlib
import os
import uuid
from Crypto.Hash import SHA, SHA256, RIPEMD
from Crypto.Cipher import AES, DES3
from M2Crypto import EVP, RSA, util, BIO, X509, m2
from xml.dom import minidom
from xml.dom.ext import c14n
from xml import xpath
from urlparse import urlparse

# Namespace URI reserved for xmlns declarations. minidom does not add
# declarations for prefixes used in createElementNS, so elements introducing
# a new prefix must carry one explicitly to serialize as well-formed XML.
_XMLNS_NS = 'http://www.w3.org/2000/xmlns/'


def _first(nodeset):
    """Return the first item of an XPath node-set, or None when it is empty."""
    if nodeset:
        return nodeset[0]
    return None


def _xpathEvaluate(query, node, namespaces):
    """Evaluate XPath *query* with *node* as the context node.

    namespaces maps every prefix used in *query* to its namespace URI.
    """
    context = xpath.CreateContext(node)
    context.setNamespaces(namespaces)
    return xpath.Evaluate(query, contextNode=node, context=context)


def _textContent(node):
    """Concatenate the text and CDATA children of *node* (direct children only)."""
    parts = []
    for child in node.childNodes:
        if child.nodeType in (child.TEXT_NODE, child.CDATA_SECTION_NODE):
            parts.append(child.data)
    return ''.join(parts)


class XMLSecurityKey:
    """A symmetric (CBC block cipher) or RSA key plus its algorithm parameters.

    The algorithm is chosen by one of the URI constants below; the same URI
    is what getAlgorith() reports for EncryptionMethod/SignatureMethod.
    """

    TRIPLEDES_CBC = 'http://www.w3.org/2001/04/xmlenc#tripledes-cbc'
    AES128_CBC = 'http://www.w3.org/2001/04/xmlenc#aes128-cbc'
    AES192_CBC = 'http://www.w3.org/2001/04/xmlenc#aes192-cbc'
    AES256_CBC = 'http://www.w3.org/2001/04/xmlenc#aes256-cbc'
    RSA_1_5 = 'http://www.w3.org/2001/04/xmlenc#rsa-1_5'
    RSA_OAEP_MGF1P = 'http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p'
    RSA_SHA1 = 'http://www.w3.org/2000/09/xmldsig#rsa-sha1'
    DSA_SHA1 = 'http://www.w3.org/2000/09/xmldsig#dsa-sha1'

    def __init__(self, type, params=None):
        """Configure the key for algorithm URI *type*.

        params: for the RSA algorithms, a dict whose 'type' entry is
            'public' or 'private' selects which half of the key pair
            loadKey() expects.
        An unsupported *type* (including DSA_SHA1) leaves the object
        unconfigured (self.type stays 0) rather than raising.
        """
        self.type = 0
        self.key = None
        self.passphrase = ""
        self.iv = None
        self.name = None
        self.keyChain = None
        self.isEncrypted = False
        self.encryptedCtx = None
        # 'library' and 'method' are always present so that getAlgorith(),
        # encryptData() etc. do not raise KeyError for an unsupported type.
        self.cryptParams = {'library': None, 'method': None, 'mode': None,
                            'cipher': None, 'padding': None, 'type': None,
                            'hash': None}

        if type in (XMLSecurityKey.TRIPLEDES_CBC, XMLSecurityKey.AES128_CBC,
                    XMLSecurityKey.AES192_CBC, XMLSecurityKey.AES256_CBC):
            if type == XMLSecurityKey.TRIPLEDES_CBC:
                cipher = DES3
            else:
                cipher = AES
                self.cryptParams['type'] = 'AES'
            self.cryptParams['library'] = 'mcrypt'
            self.cryptParams['cipher'] = cipher
            self.cryptParams['mode'] = cipher.MODE_CBC
            # The algorithm URI is also the EncryptionMethod Algorithm value.
            self.cryptParams['method'] = type
        elif type in (XMLSecurityKey.RSA_1_5, XMLSecurityKey.RSA_OAEP_MGF1P,
                      XMLSecurityKey.RSA_SHA1):
            self.cryptParams['library'] = 'openssl'
            self.cryptParams['method'] = type
            if type == XMLSecurityKey.RSA_1_5:
                self.cryptParams['padding'] = RSA.pkcs1_padding
            elif type == XMLSecurityKey.RSA_OAEP_MGF1P:
                self.cryptParams['padding'] = RSA.pkcs1_oaep_padding
            if isinstance(params, dict) and params.get('type') in ('public', 'private'):
                self.cryptParams['type'] = params['type']
        else:
            return
        self.type = type

    def loadKey(self, key, isFile=False, isCert=True):
        """Load key material into self.key.

        key: the key data, or a file path when isFile is true.
        isCert: for a 'public' RSA key, whether *key* is a PEM X.509
            certificate (True) or a PEM public key (False). Private RSA keys
            are loaded with RSA.load_key_string either way.
        Symmetric keys are stored as the raw string. An AES-256/AES-192 key
        shorter than 25/17 characters is rejected (not stored).
        """
        if isFile:
            fp = open(key, 'rb')
            try:
                key = fp.read()
            finally:
                fp.close()

        if self.cryptParams['library'] == 'openssl':
            if self.cryptParams['type'] == 'public':
                if isCert:
                    self.key = X509.load_cert_string(key).get_pubkey()
                else:
                    bio = BIO.MemoryBuffer(key)
                    self.key = EVP.PKey()
                    self.key.assign_rsa(RSA.load_pub_key_bio(bio))
            elif isCert:
                self.key = RSA.load_key_string(key)
            else:
                self.key = RSA.load_key_string(key, None)
            return None

        if self.cryptParams['type'] == 'AES':
            if self.type == XMLSecurityKey.AES256_CBC and len(key) < 25:
                return None
            if self.type == XMLSecurityKey.AES192_CBC and len(key) < 17:
                return None
        # Symmetric key: used directly by encryptMcrypt/decryptMcrypt.
        self.key = key
        return None

    def encryptMcrypt(self, data):
        """Encrypt *data* with the configured CBC block cipher.

        A fresh random IV is stored in self.iv and prepended to the result.
        The plaintext is padded to a whole number of blocks with random
        octets whose last octet is the pad length (1..block size), which is
        exactly what decryptMcrypt strips.
        """
        cipher = self.cryptParams['cipher']
        blockSize = cipher.block_size
        self.iv = os.urandom(blockSize)
        padLength = blockSize - (len(data) % blockSize)
        padded = data + os.urandom(padLength - 1) + chr(padLength)
        encryptor = cipher.new(self.key, self.cryptParams['mode'], self.iv)
        return self.iv + encryptor.encrypt(padded)

    def decryptMcrypt(self, data):
        """Decrypt ``IV || ciphertext`` with the configured CBC block cipher.

        Returns the plaintext with the trailing padding removed (the last
        plaintext octet gives the number of padding octets), or None when
        *data* does not contain an IV plus at least one octet.
        """
        cipher = self.cryptParams['cipher']
        blockSize = cipher.block_size
        if len(data) < blockSize:
            return None
        self.iv = data[:blockSize]
        body = data[blockSize:]
        # PyCrypto only decrypts whole blocks: append filler (always at least
        # one block, as the original did) and discard the extra plaintext.
        filler = blockSize - (len(body) % blockSize)
        decryptor = cipher.new(self.key, self.cryptParams['mode'], self.iv)
        decrypted = decryptor.decrypt(body + '0' * filler)[:len(body)]
        if not decrypted:
            return None
        padLength = ord(decrypted[-1])
        return decrypted[:len(decrypted) - padLength]

    def _getRSA(self):
        """Return self.key as an M2Crypto RSA object.

        Public keys are stored as EVP.PKey by loadKey(); those are unwrapped
        with get_rsa() so the RSA public_* operations are available.
        """
        if isinstance(self.key, EVP.PKey):
            return self.key.get_rsa()
        return self.key

    def encryptOpenSSL(self, data):
        """RSA-encrypt *data* with the public or private key, or return None on failure."""
        rsa = self._getRSA()
        if self.cryptParams['type'] == 'public':
            encrypted = rsa.public_encrypt(data, self.cryptParams['padding'])
        else:
            encrypted = rsa.private_encrypt(data, self.cryptParams['padding'])
        if not encrypted:
            print('Failure encrypting Data')
            return None
        return encrypted

    def decryptOpenSSL(self, data):
        """RSA-decrypt *data* with the public or private key, or return None on failure."""
        rsa = self._getRSA()
        if self.cryptParams['type'] == 'public':
            decrypted = rsa.public_decrypt(data, self.cryptParams['padding'])
        else:
            decrypted = rsa.private_decrypt(data, self.cryptParams['padding'])
        if not decrypted:
            print('Failure decrypting Data')
            return None
        return decrypted

    def signOpenSSL(self, data):
        """Sign *data* through the EVP interface and return the raw signature.

        Private keys are loaded as plain RSA objects (see loadKey), which
        have no sign_init; they are wrapped in an EVP.PKey first.
        NOTE(review): EVP.PKey's default digest is assumed to be SHA-1,
        matching RSA_SHA1 — confirm against the installed M2Crypto.
        """
        pKey = self.key
        if not isinstance(pKey, EVP.PKey):
            wrapped = EVP.PKey()
            # capture=0: share the RSA key rather than transferring ownership.
            wrapped.assign_rsa(pKey, capture=0)
            pKey = wrapped
        pKey.sign_init()
        pKey.sign_update(data)
        return pKey.sign_final()

    def verifyOpenSSL(self, data, signature):
        """Return True only when *signature* over *data* verifies.

        The underlying verify_final returns 1 (valid), 0 (invalid) or -1
        (error); -1 is truthy, so the result is compared to 1 explicitly.
        """
        pKey = self.key
        pKey.verify_init()
        pKey.verify_update(data)
        return m2.verify_final(pKey.ctx, signature, pKey.pkey) == 1

    def encryptData(self, data):
        """Encrypt *data* with the configured library; None if unconfigured."""
        if self.cryptParams['library'] == 'mcrypt':
            return self.encryptMcrypt(data)
        if self.cryptParams['library'] == 'openssl':
            return self.encryptOpenSSL(data)
        return None

    def decryptData(self, data):
        """Decrypt *data* with the configured library; None if unconfigured."""
        if self.cryptParams['library'] == 'mcrypt':
            return self.decryptMcrypt(data)
        if self.cryptParams['library'] == 'openssl':
            return self.decryptOpenSSL(data)
        return None

    def signData(self, data):
        """Sign *data* (RSA keys only); None for symmetric keys."""
        if self.cryptParams['library'] == 'openssl':
            return self.signOpenSSL(data)
        return None

    def verifySignature(self, data, signature):
        """Verify *signature* over *data* (RSA keys only); None for symmetric keys."""
        if self.cryptParams['library'] == 'openssl':
            return self.verifyOpenSSL(data, signature)
        return None

    def getAlgorith(self):
        """Return the algorithm URI this key was configured with (or None)."""
        return self.cryptParams['method']

    def makeAsnSegment(type, instring):
        """DER-encode *instring* as an ASN.1 TLV with tag *type*.

        INTEGER (0x02) values get a leading 0x00 when their high bit is set,
        so they are not read as negative. BIT STRING (0x03) values get the
        leading "unused bits" octet (0). Returns None for contents of
        64 KiB or more, which this encoder does not support.
        """
        if type == 0x02:
            if instring and ord(instring[0]) > 0x7f:
                instring = chr(0) + instring
        elif type == 0x03:
            instring = chr(0) + instring
        length = len(instring)
        if length < 128:
            return "%c%c%s" % (type, length, instring)
        if length < 0x0100:
            return "%c%c%c%s" % (type, 0x81, length, instring)
        if length < 0x010000:
            return "%c%c%c%c%s" % (type, 0x82, length // 0x0100, length % 0x0100, instring)
        return None

    def convertRSA(modulus, exponent):
        """Build a PEM "PUBLIC KEY" (SubjectPublicKeyInfo) from raw RSA parts.

        modulus and exponent must already be base64-decoded big-endian octets.
        """
        exponentEncoding = XMLSecurityKey.makeAsnSegment(0x02, exponent)
        modulusEncoding = XMLSecurityKey.makeAsnSegment(0x02, modulus)
        sequenceEncoding = XMLSecurityKey.makeAsnSegment(0x30, modulusEncoding + exponentEncoding)
        bitstringEncoding = XMLSecurityKey.makeAsnSegment(0x03, sequenceEncoding)
        # AlgorithmIdentifier { rsaEncryption, NULL }
        rsaAlgorithmIdentifier = util.h2b("300D06092A864886F70D0101010500")
        publicKeyInfo = XMLSecurityKey.makeAsnSegment(0x30, rsaAlgorithmIdentifier + bitstringEncoding)

        # Base64 in 64-character lines between the PEM armour lines.
        publicKeyInfoBase64 = base64.b64encode(publicKeyInfo)
        lines = [publicKeyInfoBase64[i:i + 64]
                 for i in range(0, len(publicKeyInfoBase64), 64)]
        return ("-----BEGIN PUBLIC KEY-----\n"
                + "".join(line + "\n" for line in lines)
                + "-----END PUBLIC KEY-----\n")

    def serializeKey(self, parent):
        """Placeholder: key serialization is not implemented."""
        return None

    makeAsnSegment = staticmethod(makeAsnSegment)
    convertRSA = staticmethod(convertRSA)


class XMLSecurityDSig:
    """Create, locate and verify a ds:Signature element."""

    XMLDSIGNS = 'http://www.w3.org/2000/09/xmldsig#'
    SHA1 = 'http://www.w3.org/2000/09/xmldsig#sha1'
    SHA256 = 'http://www.w3.org/2001/04/xmlenc#sha256'
    SHA512 = 'http://www.w3.org/2001/04/xmlenc#sha512'
    RIPEMD160 = 'http://www.w3.org/2001/04/xmlenc#ripemd160'

    C14N = 'http://www.w3.org/TR/2001/REC-xml-c14n-20010315'
    C14N_COMMENTS = 'http://www.w3.org/TR/2001/REC-xml-c14n-20010315#WithComments'
    EXC_C14N = 'http://www.w3.org/2001/10/xml-exc-c14n#'
    EXC_C14N_COMMENTS = 'http://www.w3.org/2001/10/xml-exc-c14n#WithComments'
    # Every canonicalization method this class understands.
    _C14N_METHODS = (C14N, C14N_COMMENTS, EXC_C14N, EXC_C14N_COMMENTS)

    # Skeleton signature: the code below queries ./secdsig:SignedInfo and
    # ./secdsig:SignedInfo/secdsig:SignatureMethod relative to the root.
    template = ('<ds:Signature xmlns:ds="http://www.w3.org/2000/09/xmldsig#">'
                '<ds:SignedInfo>'
                '<ds:SignatureMethod />'
                '</ds:SignedInfo>'
                '</ds:Signature>')

    def __init__(self):
        # idKeys: extra ID attribute names (besides "Id") for URI="#id" lookups.
        # idNS: prefix -> namespace URI map for prefixed names in idKeys.
        self.idKeys = {}
        self.idNS = {}
        self.signedInfo = None
        self.xPathCtx = None
        self.canonicalMethod = None
        self.prefix = 'ds'
        self.searchpfx = 'secdsig'
        sigdoc = minidom.parseString(XMLSecurityDSig.template)
        self.sigNode = sigdoc.documentElement

    def _query(self, query, node=None):
        """Evaluate *query* (which uses the 'secdsig' prefix) relative to
        *node*, defaulting to the signature node."""
        if node is None:
            node = self.sigNode
        return _xpathEvaluate(query, node, {'secdsig': XMLSecurityDSig.XMLDSIGNS})

    def getXPathObj(self):
        """Return a cached XPath context bound to the signature node."""
        if not self.xPathCtx and self.sigNode:
            xPath = xpath.CreateContext(self.sigNode)
            xPath.setNamespaces({'secdsig': XMLSecurityDSig.XMLDSIGNS})
            self.xPathCtx = xPath
        return self.xPathCtx

    def generate_GUID(prefix=None):
        """Return a random identifier of the form prefix + 8-4-4-4-12 hex digits.

        prefix defaults to 'pfx' so the result starts with a letter (an XML
        ID must be an NCName); pass '' for the bare GUID.
        """
        if prefix is None:
            prefix = 'pfx'
        uid = uuid.uuid4().hex
        return prefix + '-'.join([uid[0:8], uid[8:12], uid[12:16],
                                  uid[16:20], uid[20:32]])

    def locateSignature(self, objDoc):
        """Find the first ds:Signature at or below *objDoc* and make it current.

        Returns the signature node, or None when there is none (the current
        signature node is then left unchanged).
        """
        if objDoc.nodeType == objDoc.DOCUMENT_NODE:
            doc = objDoc
        else:
            doc = objDoc.ownerDocument
        if doc is None:
            return None
        sigNode = _first(self._query('.//secdsig:Signature', objDoc))
        if sigNode is None:
            return None
        self.sigNode = sigNode
        # The cached XPath context was bound to the previous signature node.
        self.xPathCtx = None
        return self.sigNode

    def createNewSignNode(self, name, value=None):
        """Create (but do not insert) a ds:<name> element, with optional text *value*."""
        doc = self.sigNode.ownerDocument
        node = doc.createElementNS(XMLSecurityDSig.XMLDSIGNS, self.prefix + ':' + name)
        if value is not None:
            node.appendChild(doc.createTextNode(value))
        return node

    def setCanonicalMethod(self, method):
        """Set the canonicalization method and record it in SignedInfo.

        Returns False (changing nothing) for an unknown method URI, else True.
        A ds:CanonicalizationMethod element is created as the first child of
        SignedInfo when missing, and its Algorithm attribute updated.
        """
        if method not in XMLSecurityDSig._C14N_METHODS:
            return False
        self.canonicalMethod = method
        sinfo = _first(self._query('./' + self.searchpfx + ':SignedInfo'))
        if sinfo is not None:
            canonNode = _first(self._query('./' + self.searchpfx + ':CanonicalizationMethod', sinfo))
            if canonNode is None:
                canonNode = self.createNewSignNode('CanonicalizationMethod')
                sinfo.insertBefore(canonNode, sinfo.firstChild)
            canonNode.setAttribute('Algorithm', self.canonicalMethod)
        return True

    def canonicalizeData(self, node, canonicalmethod):
        """Serialize *node* with the canonicalization named by *canonicalmethod*.

        Unknown or missing methods fall back to inclusive c14n without comments.
        """
        exclusive = canonicalmethod in (XMLSecurityDSig.EXC_C14N,
                                        XMLSecurityDSig.EXC_C14N_COMMENTS)
        withComments = canonicalmethod in (XMLSecurityDSig.C14N_COMMENTS,
                                           XMLSecurityDSig.EXC_C14N_COMMENTS)
        if exclusive:
            # NOTE(review): PyXML's c14n appears to switch to exclusive mode
            # only when unsuppressedPrefixes is not None (None = inclusive),
            # hence [] rather than None — confirm against installed PyXML.
            return c14n.Canonicalize(node, None, comments=withComments,
                                     unsuppressedPrefixes=[])
        return c14n.Canonicalize(node, None, comments=withComments)

    def canonicalizeSignedInfo(self):
        """Canonicalize ds:SignedInfo, store the result in self.signedInfo and return it.

        Uses the Algorithm of SignedInfo's ds:CanonicalizationMethod child, if any.
        Returns None when there is no SignedInfo.
        """
        if self.sigNode.ownerDocument is None:
            return None
        signInfoNode = _first(self._query('./secdsig:SignedInfo'))
        if signInfoNode is None:
            return None
        canonicalmethod = None
        canonNode = signInfoNode.firstChild
        # Skip every child that is not exactly ds:CanonicalizationMethod.
        while canonNode is not None and (
                canonNode.localName != 'CanonicalizationMethod'
                or canonNode.namespaceURI != XMLSecurityDSig.XMLDSIGNS):
            canonNode = canonNode.nextSibling
        if canonNode is not None:
            canonicalmethod = canonNode.getAttribute('Algorithm')
        self.signedInfo = self.canonicalizeData(signInfoNode, canonicalmethod)
        return self.signedInfo

    def calculateDigest(self, digestAlgorithm, data):
        """Return the base64 digest of *data*, or None for an unsupported algorithm URI."""
        if digestAlgorithm == XMLSecurityDSig.SHA1:
            alg = SHA.new()
        elif digestAlgorithm == XMLSecurityDSig.SHA256:
            alg = SHA256.new()
        elif digestAlgorithm == XMLSecurityDSig.SHA512:
            alg = hashlib.sha512()
        elif digestAlgorithm == XMLSecurityDSig.RIPEMD160:
            alg = RIPEMD.new()
        else:
            return None
        alg.update(data)
        return base64.b64encode(alg.digest())

    def validateDigest(self, refNode, data):
        """Return True when ds:DigestValue of *refNode* matches the digest of *data*."""
        digestAlgorithm = self._query('string(./secdsig:DigestMethod/@Algorithm)', refNode)
        digValue = self.calculateDigest(digestAlgorithm, data)
        digestValue = self._query('string(./secdsig:DigestValue)', refNode)
        # The stored base64 text may be wrapped or indented; ignore whitespace.
        return digValue is not None and digValue == ''.join(digestValue.split())

    def processTransforms(self, refNode, objData):
        """Apply the reference's canonicalization transform to *objData*.

        Only canonicalization transforms are honoured (the first one found);
        inclusive c14n is the default. Non-DOM data is returned unchanged.
        """
        canonicalMethod = XMLSecurityDSig.C14N
        for transform in self._query('./secdsig:Transforms/secdsig:Transform', refNode):
            algorithm = transform.getAttribute("Algorithm")
            if algorithm in XMLSecurityDSig._C14N_METHODS:
                canonicalMethod = algorithm
                break
        if isinstance(objData, minidom.Node):
            return self.canonicalizeData(objData, canonicalMethod)
        return objData

    def _locateById(self, doc, identifier):
        """Return the element of *doc* whose Id (or any self.idKeys attribute)
        equals *identifier*, or None."""
        # Quotes would break out of the XPath string literal below.
        if '"' in identifier or "'" in identifier:
            return None
        if isinstance(self.idKeys, dict):
            idKeys = self.idKeys.values()
        else:
            idKeys = self.idKeys or []
        conditions = ['@Id="' + identifier + '"']
        for idKey in idKeys:
            conditions.append("@" + idKey + "='" + identifier + "'")
        namespaces = {}
        if isinstance(self.idNS, dict):
            namespaces.update(self.idNS)
        query = '//*[' + ' or '.join(conditions) + ']'
        return _first(_xpathEvaluate(query, doc, namespaces))

    def processRefNode(self, refNode):
        """Check the digest of the data referenced by a ds:Reference.

        An empty URI or a bare "#" means the whole document; "#id" selects
        the element with that ID. References outside the document are not
        fetched (dereferencing untrusted URIs is unsafe) and fail.
        """
        uri = refNode.getAttribute("URI")
        if not uri:
            dataObject = refNode.ownerDocument
        else:
            arUrl = urlparse(uri)
            if arUrl[0] or arUrl[1] or arUrl[2]:
                return False
            identifier = arUrl[5]
            if not identifier:
                dataObject = refNode.ownerDocument
            else:
                dataObject = self._locateById(refNode.ownerDocument, identifier)
                if dataObject is None:
                    return False
        data = self.processTransforms(refNode, dataObject)
        return self.validateDigest(refNode, data)

    def validateReference(self):
        """Validate every ds:Reference digest; True if all match, else None.

        The signature node is first detached from its parent, so that it is
        not included in the content digested for the references.
        """
        doc = self.sigNode.ownerDocument
        if doc is not self.sigNode and self.sigNode.parentNode is not None:
            self.sigNode.parentNode.removeChild(self.sigNode)
        nodeset = self._query('./secdsig:SignedInfo/secdsig:Reference')
        if not nodeset:
            return None
        for refNode in nodeset:
            if not self.processRefNode(refNode):
                return None
        return True

    def addRefInternal(self, sinfoNode, node, algorithm, arTransforms=None, options=None):
        """Append a ds:Reference for *node* to *sinfoNode*.

        A non-document *node* receives a generated ID attribute (named by
        options['id_name'], default 'Id', optionally prefixed with
        options['prefix'] in namespace options['prefix_ns']) which the
        Reference URI points to. arTransforms is a list of transform URIs;
        without it the current canonical method (if any) is used.
        """
        prefix = None
        prefix_ns = None
        id_name = 'Id'
        if isinstance(options, dict):
            prefix = options.get('prefix') or None
            prefix_ns = options.get('prefix_ns') or None
            id_name = options.get('id_name') or 'Id'

        refNode = self.createNewSignNode('Reference')
        sinfoNode.appendChild(refNode)

        if node.nodeType == node.DOCUMENT_NODE:
            uri = None
            refNode.setAttribute("URI", '')
        else:
            uri = XMLSecurityDSig.generate_GUID()
            refNode.setAttribute("URI", '#' + uri)

        transNodes = self.createNewSignNode('Transforms')
        refNode.appendChild(transNodes)

        if isinstance(arTransforms, dict):
            arTransforms = list(arTransforms.values())
        if isinstance(arTransforms, (list, tuple)):
            for transform in arTransforms:
                transNode = self.createNewSignNode('Transform')
                transNodes.appendChild(transNode)
                transNode.setAttribute('Algorithm', transform)
        elif self.canonicalMethod:
            transNode = self.createNewSignNode('Transform')
            transNodes.appendChild(transNode)
            transNode.setAttribute('Algorithm', self.canonicalMethod)

        # The ID is set before digesting so it is part of the signed content.
        if uri:
            attname = id_name
            if prefix:
                attname = prefix + ':' + attname
            node.setAttributeNS(prefix_ns, attname, uri)

        canonicalData = self.processTransforms(refNode, node)
        digValue = self.calculateDigest(algorithm, canonicalData)

        digestMethod = self.createNewSignNode('DigestMethod')
        refNode.appendChild(digestMethod)
        digestMethod.setAttribute('Algorithm', algorithm)

        digestValue = self.createNewSignNode('DigestValue', digValue)
        refNode.appendChild(digestValue)

    def addReference(self, node, algorithm, arTransforms=None, options=None):
        """Add a ds:Reference for *node* to SignedInfo (see addRefInternal)."""
        sInfo = _first(self._query('./secdsig:SignedInfo'))
        if sInfo is not None:
            self.addRefInternal(sInfo, node, algorithm, arTransforms, options)

    def addReferenceList(self, arNodes, algorithm, arTransforms=None, options=None):
        """Add one ds:Reference per node in *arNodes* (see addRefInternal)."""
        sInfo = _first(self._query('./secdsig:SignedInfo'))
        if sInfo is not None:
            for node in arNodes:
                self.addRefInternal(sInfo, node, algorithm, arTransforms, options)

    def locateKey(self, node=None):
        """Return a public XMLSecurityKey for the SignatureMethod under *node*
        (default: the signature node), or None."""
        if node is None:
            node = self.sigNode
        if not isinstance(node, minidom.Node):
            return None
        if node.ownerDocument is None:
            return None
        algorithm = self._query(
            'string(./secdsig:SignedInfo/secdsig:SignatureMethod/@Algorithm)', node)
        if algorithm:
            return XMLSecurityKey(algorithm, {'type': 'public'})
        return None

    def verify(self, objKey):
        """Verify ds:SignatureValue against self.signedInfo using *objKey*.

        canonicalizeSignedInfo() must have been called first. Returns None
        when there is no SignatureValue.
        """
        sigValue = self._query('string(./secdsig:SignatureValue)')
        if not sigValue:
            return None
        return objKey.verifySignature(self.signedInfo, base64.b64decode(sigValue))

    def signData(self, objKey, data):
        """Sign *data* with *objKey*."""
        return objKey.signData(data)

    def sign(self, objKey):
        """Sign SignedInfo with *objKey* and insert ds:SignatureValue after it."""
        sInfo = _first(self._query('./secdsig:SignedInfo'))
        if sInfo is None:
            return
        sMethod = _first(self._query('./secdsig:SignatureMethod', sInfo))
        if sMethod is not None:
            sMethod.setAttribute('Algorithm', objKey.type)
        data = self.canonicalizeData(sInfo, self.canonicalMethod)
        sigValue = base64.b64encode(self.signData(objKey, data))
        sigValueNode = self.createNewSignNode('SignatureValue', sigValue)
        infoSibling = sInfo.nextSibling
        if infoSibling is not None:
            infoSibling.parentNode.insertBefore(sigValueNode, infoSibling)
        else:
            self.sigNode.appendChild(sigValueNode)

    def appendCert(self):
        """Placeholder: not implemented."""
        return None

    def appendKey(self, objKey, parent=None):
        """Delegate key serialization to objKey.serializeKey(parent)."""
        objKey.serializeKey(parent)

    def appendSignature(self, parent, insertBefore=False):
        """Import the signature into *parent*'s document and attach it.

        The copy is inserted as the first child when insertBefore is true,
        otherwise appended. Returns the imported signature node.
        """
        if parent.nodeType == parent.DOCUMENT_NODE:
            baseDoc = parent
        else:
            baseDoc = parent.ownerDocument
        newSig = baseDoc.importNode(self.sigNode, True)
        if insertBefore:
            parent.insertBefore(newSig, parent.firstChild)
        else:
            parent.appendChild(newSig)
        return newSig

    def get509XCert(cert, isPEMFormat=True):
        """Return the base64 body of the first certificate in PEM text *cert*.

        Lines between '-----BEGIN CERTIFICATE' and '-----END CERTIFICATE' are
        stripped and concatenated. When isPEMFormat is false *cert* is
        returned unchanged.
        """
        if not isPEMFormat:
            return cert
        parts = []
        inData = False
        for curData in cert.split("\n"):
            if not inData:
                if curData.startswith('-----BEGIN CERTIFICATE'):
                    inData = True
            elif curData.startswith('-----END CERTIFICATE'):
                break
            else:
                parts.append(curData.strip())
        return ''.join(parts)

    def add509Cert(self, cert, isPEMFormat=True):
        """Add a ds:X509Data/ds:X509Certificate for *cert* under ds:KeyInfo.

        ds:KeyInfo is created when missing: before the first ds:Object if
        there is one, otherwise appended to the signature.
        """
        data = XMLSecurityDSig.get509XCert(cert, isPEMFormat)
        keyInfo = _first(self._query('./secdsig:KeyInfo'))
        if keyInfo is None:
            keyInfo = self.createNewSignNode('KeyInfo')
            sObject = _first(self._query('./secdsig:Object'))
            if sObject is not None:
                sObject.parentNode.insertBefore(keyInfo, sObject)
            else:
                self.sigNode.appendChild(keyInfo)
        x509DataNode = self.createNewSignNode('X509Data')
        keyInfo.appendChild(x509DataNode)
        x509CertNode = self.createNewSignNode('X509Certificate', data)
        x509DataNode.appendChild(x509CertNode)

    generate_GUID = staticmethod(generate_GUID)
    get509XCert = staticmethod(get509XCert)


class XMLSecEnc:
    """Encrypt/decrypt a node as xenc:EncryptedData and locate its keys."""

    Element = 'http://www.w3.org/2001/04/xmlenc#Element'
    Content = 'http://www.w3.org/2001/04/xmlenc#Content'
    URI = 3
    XMLENCNS = 'http://www.w3.org/2001/04/xmlenc#'

    # Skeleton EncryptedData; encryptNode fills CipherValue.
    template = ("<xenc:EncryptedData xmlns:xenc='http://www.w3.org/2001/04/xmlenc#'>"
                "<xenc:CipherData>"
                "<xenc:CipherValue></xenc:CipherValue>"
                "</xenc:CipherData>"
                "</xenc:EncryptedData>")

    def __init__(self):
        self.rawNode = None
        self.type = None  # XMLSecEnc.Element or XMLSecEnc.Content
        self.encdoc = minidom.parseString(XMLSecEnc.template)

    def setNode(self, node):
        """Set the node to encrypt, or the EncryptedData/EncryptedKey node to decrypt."""
        self.rawNode = node

    def _createDsigElement(self, qualifiedName):
        """Create a dsig:-prefixed element in self.encdoc carrying its own
        xmlns:dsig declaration (minidom does not add one)."""
        element = self.encdoc.createElementNS(XMLSecurityDSig.XMLDSIGNS, qualifiedName)
        element.setAttributeNS(_XMLNS_NS, 'xmlns:dsig', XMLSecurityDSig.XMLDSIGNS)
        return element

    def encryptNode(self, objKey, replace=True):
        """Encrypt self.rawNode (Element) or its children (Content) with *objKey*.

        With replace, the plaintext is swapped for the imported
        EncryptedData element, which is returned (for a Document in Element
        mode the EncryptedData document itself is returned). Without
        replace, the EncryptedData element of self.encdoc is returned.
        Returns None when no node/type is set or encryption fails.
        """
        if self.rawNode is None:
            return None
        root = self.encdoc.documentElement
        cipherValue = _first(self.encdoc.getElementsByTagNameNS(XMLSecEnc.XMLENCNS, 'CipherValue'))
        if cipherValue is None:
            return None

        # Serialize as UTF-8 octets so non-ASCII text can be encrypted.
        if self.type == XMLSecEnc.Element:
            data = self.rawNode.toxml('utf-8')
        elif self.type == XMLSecEnc.Content:
            data = ''.join([child.toxml('utf-8') for child in self.rawNode.childNodes])
        else:
            return None

        # Encrypt before touching encdoc so a failure leaves it unchanged.
        encrypted = objKey.encryptData(data)
        if encrypted is None:
            return None

        root.setAttribute('Type', self.type)
        encMethod = self.encdoc.createElementNS(XMLSecEnc.XMLENCNS, 'xenc:EncryptionMethod')
        encMethod.setAttribute('Algorithm', objKey.getAlgorith())
        cipherData = cipherValue.parentNode
        cipherData.parentNode.insertBefore(encMethod, cipherData)
        cipherValue.appendChild(self.encdoc.createTextNode(base64.b64encode(encrypted)))

        if not replace:
            return root

        if self.type == XMLSecEnc.Element:
            if self.rawNode.nodeType == self.rawNode.DOCUMENT_NODE:
                return self.encdoc
            importEnc = self.rawNode.ownerDocument.importNode(root, True)
            self.rawNode.parentNode.replaceChild(importEnc, self.rawNode)
            return importEnc

        if self.rawNode.nodeType == self.rawNode.DOCUMENT_NODE:
            doc = self.rawNode
        else:
            doc = self.rawNode.ownerDocument
        importEnc = doc.importNode(root, True)
        while self.rawNode.firstChild is not None:
            self.rawNode.removeChild(self.rawNode.firstChild)
        self.rawNode.appendChild(importEnc)
        return importEnc

    def decryptNode(self, objKey, replace=True):
        """Decrypt the CipherValue of self.rawNode with *objKey*.

        Without replace (or for an unknown type) the decrypted octets are
        returned. With replace, Element mode swaps rawNode for the decrypted
        element (returned) and Content mode swaps it for the decrypted nodes
        (their parent is returned). Returns None when there is no node, no
        CipherValue, or decryption fails.
        """
        if self.rawNode is None:
            print('Node to decrypt has not been set')
            return None
        encryptedData = _xpathEvaluate('string(./xmlencr:CipherData/xmlencr:CipherValue)',
                                       self.rawNode, {'xmlencr': XMLSecEnc.XMLENCNS})
        if not encryptedData:
            return None
        decrypted = objKey.decryptData(base64.b64decode(encryptedData))
        if decrypted is None or not replace:
            return decrypted

        if self.type == XMLSecEnc.Element:
            newdoc = minidom.parseString(decrypted)
            if self.rawNode.nodeType == self.rawNode.DOCUMENT_NODE:
                return newdoc
            importEnc = self.rawNode.ownerDocument.importNode(newdoc.documentElement, True)
            self.rawNode.parentNode.replaceChild(importEnc, self.rawNode)
            return importEnc

        if self.type == XMLSecEnc.Content:
            if self.rawNode.nodeType == self.rawNode.DOCUMENT_NODE:
                doc = self.rawNode
            else:
                doc = self.rawNode.ownerDocument
            # Parse the (possibly multi-node) content inside a dummy root.
            # NOTE(review): prefixes declared only on ancestors of rawNode are
            # not in scope here and will fail to parse.
            wrapper = minidom.parseString('<xmlsec-content>' + decrypted + '</xmlsec-content>')
            newFrag = doc.createDocumentFragment()
            for child in list(wrapper.documentElement.childNodes):
                newFrag.appendChild(doc.importNode(child, True))
            parent = self.rawNode.parentNode
            parent.replaceChild(newFrag, self.rawNode)
            return parent

        return decrypted

    def encryptKey(self, srcKey, rawKey, append=True):
        """Encrypt rawKey.key with *srcKey* into an xenc:EncryptedKey element.

        With append, the EncryptedKey is placed in a dsig:KeyInfo inserted
        before xenc:CipherData of the EncryptedData. Returns the EncryptedKey
        element, or None when encryption fails.
        """
        encrypted = srcKey.encryptData(rawKey.key)
        if encrypted is None:
            return None
        strEncKey = base64.b64encode(encrypted)
        doc = self.encdoc
        root = doc.documentElement
        encKey = doc.createElementNS(XMLSecEnc.XMLENCNS, 'xenc:EncryptedKey')
        if append:
            keyInfo = self._createDsigElement('dsig:KeyInfo')
            cipherData = None
            for child in root.childNodes:
                if child.localName == 'CipherData' and child.namespaceURI == XMLSecEnc.XMLENCNS:
                    cipherData = child
                    break
            # insertBefore(x, None) appends when CipherData is absent.
            root.insertBefore(keyInfo, cipherData)
            keyInfo.appendChild(encKey)

        encMethod = encKey.appendChild(doc.createElementNS(XMLSecEnc.XMLENCNS, 'xenc:EncryptionMethod'))
        encMethod.setAttribute('Algorithm', srcKey.getAlgorith())
        if srcKey.name:
            keyInfo = encKey.appendChild(self._createDsigElement('dsig:KeyInfo'))
            keyName = keyInfo.appendChild(self._createDsigElement('dsig:KeyName'))
            keyName.appendChild(doc.createTextNode(srcKey.name))
        cipherData = encKey.appendChild(doc.createElementNS(XMLSecEnc.XMLENCNS, 'xenc:CipherData'))
        cipherValue = cipherData.appendChild(doc.createElementNS(XMLSecEnc.XMLENCNS, 'xenc:CipherValue'))
        cipherValue.appendChild(doc.createTextNode(strEncKey))
        return encKey

    def decryptKey(self, encKey):
        """Decrypt this EncryptedKey node with key *encKey* and return the raw key.

        Returns None unless encKey.isEncrypted is set and encKey.key is loaded.
        """
        if not encKey.isEncrypted:
            return None
        if not encKey.key:
            return None
        return self.decryptNode(encKey, False)

    def locateEncryptedData(self, element):
        """Return the first xenc:EncryptedData in *element*'s document, or None."""
        if element.nodeType == element.DOCUMENT_NODE:
            doc = element
        else:
            doc = element.ownerDocument
        if doc is None:
            return None
        query = ("//*[local-name()='EncryptedData' and namespace-uri()='"
                 + XMLSecEnc.XMLENCNS + "']")
        return _first(xpath.Evaluate(query, doc))

    def locateKey(self, node=None):
        """Return a private XMLSecurityKey for the first xenc:EncryptionMethod
        under *node* (default: self.rawNode), or None."""
        if node is None:
            node = self.rawNode
        if not isinstance(node, minidom.Node):
            return None
        if node.ownerDocument is None:
            return None
        encmeth = _first(_xpathEvaluate('.//xmlsecenc:EncryptionMethod', node,
                                        {'xmlsecenc': XMLSecEnc.XMLENCNS}))
        if encmeth is None:
            return None
        return XMLSecurityKey(encmeth.getAttribute("Algorithm"), {'type': 'private'})

    def staticLocateKeyInfo(objBaseKey=None, node=None):
        """Read the ds:KeyInfo child of *node* into *objBaseKey*.

        Handles KeyName (sets objBaseKey.name), RSAKeyValue and X509Data
        (load the key), and EncryptedKey (returns a new encrypted key whose
        encryptedCtx can decrypt it). Returns objBaseKey, or None on
        unsupported/incomplete key info.
        """
        if not isinstance(node, minidom.Node):
            return None
        if node.ownerDocument is None:
            return None
        keyInfo = _first(_xpathEvaluate('./xmlsecdsig:KeyInfo', node,
                                        {'xmlsecdsig': XMLSecurityDSig.XMLDSIGNS,
                                         'xmlsecenc': XMLSecEnc.XMLENCNS}))
        if keyInfo is None:
            return objBaseKey

        for child in keyInfo.childNodes:
            name = child.localName
            if name == 'KeyName':
                if objBaseKey is not None:
                    objBaseKey.name = _textContent(child)
            elif name == 'KeyValue':
                for keyval in child.childNodes:
                    if keyval.localName == 'DSAKeyValue':
                        # DSAKeyValue is not supported.
                        return None
                    if keyval.localName == 'RSAKeyValue':
                        modulus = None
                        exponent = None
                        modulusNode = keyval.getElementsByTagNameNS(XMLSecurityDSig.XMLDSIGNS, 'Modulus')
                        if modulusNode:
                            modulus = base64.b64decode(_textContent(modulusNode[0]))
                        exponentNode = keyval.getElementsByTagNameNS(XMLSecurityDSig.XMLDSIGNS, 'Exponent')
                        if exponentNode:
                            exponent = base64.b64decode(_textContent(exponentNode[0]))
                        if not modulus or not exponent:
                            return None
                        publicKey = XMLSecurityKey.convertRSA(modulus, exponent)
                        if objBaseKey is not None:
                            # A PEM public key, not a certificate.
                            objBaseKey.loadKey(publicKey, False, False)
            elif name == 'RetrievalMethod':
                # Not currently supported.
                pass
            elif name == 'EncryptedKey':
                objenc = XMLSecEnc()
                objenc.setNode(child)
                objKey = objenc.locateKey()
                if objKey is None:
                    return None
                objKey.isEncrypted = True
                objKey.encryptedCtx = objenc
                XMLSecEnc.staticLocateKeyInfo(objKey, child)
                return objKey
            elif name == 'X509Data':
                certNodes = child.getElementsByTagNameNS(XMLSecurityDSig.XMLDSIGNS, 'X509Certificate')
                if certNodes:
                    x509cert = ''.join(_textContent(certNodes[0]).split())
                    pem = ("-----BEGIN CERTIFICATE-----\n"
                           + ''.join([x509cert[i:i + 64] + "\n"
                                      for i in range(0, len(x509cert), 64)])
                           + "-----END CERTIFICATE-----\n")
                    if objBaseKey is not None:
                        objBaseKey.loadKey(pem)
        return objBaseKey

    def locateKeyInfo(self, objBaseKey=None, node=None):
        """staticLocateKeyInfo on *node*, defaulting to self.rawNode."""
        if node is None:
            node = self.rawNode
        return XMLSecEnc.staticLocateKeyInfo(objBaseKey, node)

    staticLocateKeyInfo = staticmethod(staticLocateKeyInfo)