Refactor data parsers

This commit is contained in:
Amir Abbas 2020-03-29 12:55:15 +04:30
parent 3ab86ac439
commit 4394c4db8e
6 changed files with 95 additions and 99 deletions

View File

@ -14,13 +14,13 @@ import SMB2.Raw
/// Provides synchronous operation on SMB2
final class SMB2Context: CustomDebugStringConvertible, CustomReflectable {
var context: UnsafeMutablePointer<smb2_context>?
var unsafe: UnsafeMutablePointer<smb2_context>?
private var _context_lock = NSRecursiveLock()
var timeout: TimeInterval
init(timeout: TimeInterval) throws {
let _context = try smb2_init_context().unwrap()
self.context = _context
self.unsafe = _context
self.timeout = timeout
}
@ -29,7 +29,7 @@ final class SMB2Context: CustomDebugStringConvertible, CustomReflectable {
try? self.disconnect()
}
try? withThreadSafeContext { (context) in
self.context = nil
self.unsafe = nil
smb2_destroy_context(context)
}
}
@ -39,7 +39,7 @@ final class SMB2Context: CustomDebugStringConvertible, CustomReflectable {
defer {
_context_lock.unlock()
}
return try handler(context.unwrap())
return try handler(unsafe.unwrap())
}
public var debugDescription: String {
@ -48,7 +48,7 @@ final class SMB2Context: CustomDebugStringConvertible, CustomReflectable {
public var customMirror: Mirror {
var c: [(label: String?, value: Any)] = []
if self.context != nil {
if self.unsafe != nil {
c.append((label: "server", value: server!))
c.append((label: "securityMode", value: securityMode))
c.append((label: "authentication", value: authentication))
@ -68,7 +68,7 @@ final class SMB2Context: CustomDebugStringConvertible, CustomReflectable {
extension SMB2Context {
var workstation: String {
get {
return (context?.pointee.workstation).map(String.init(cString:)) ?? ""
return (unsafe?.pointee.workstation).map(String.init(cString:)) ?? ""
}
set {
try? withThreadSafeContext { (context) in
@ -79,7 +79,7 @@ extension SMB2Context {
var domain: String {
get {
return (context?.pointee.domain).map(String.init(cString:)) ?? ""
return (unsafe?.pointee.domain).map(String.init(cString:)) ?? ""
}
set {
try? withThreadSafeContext { (context) in
@ -90,7 +90,7 @@ extension SMB2Context {
var user: String {
get {
return (context?.pointee.user).map(String.init(cString:)) ?? ""
return (unsafe?.pointee.user).map(String.init(cString:)) ?? ""
}
set {
try? withThreadSafeContext { (context) in
@ -101,7 +101,7 @@ extension SMB2Context {
var password: String {
get {
return (context?.pointee.password).map(String.init(cString:)) ?? ""
return (unsafe?.pointee.password).map(String.init(cString:)) ?? ""
}
set {
try? withThreadSafeContext { (context) in
@ -112,7 +112,7 @@ extension SMB2Context {
var securityMode: NegotiateSigning {
get {
return (context?.pointee.security_mode).flatMap(NegotiateSigning.init(rawValue:)) ?? []
return (unsafe?.pointee.security_mode).flatMap(NegotiateSigning.init(rawValue:)) ?? []
}
set {
try? withThreadSafeContext { (context) in
@ -123,7 +123,7 @@ extension SMB2Context {
var seal: Bool {
get {
return context?.pointee.seal ?? 0 != 0
return unsafe?.pointee.seal ?? 0 != 0
}
set {
try? withThreadSafeContext { (context) in
@ -134,7 +134,7 @@ extension SMB2Context {
var authentication: Security {
get {
return context?.pointee.sec ?? SMB2_SEC_UNDEFINED
return unsafe?.pointee.sec ?? SMB2_SEC_UNDEFINED
}
set {
try? withThreadSafeContext { (context) in
@ -144,7 +144,7 @@ extension SMB2Context {
}
var clientGuid: UUID? {
guard let guid = try? smb2_get_client_guid(context.unwrap()) else {
guard let guid = try? smb2_get_client_guid(unsafe.unwrap()) else {
return nil
}
let uuid = UnsafeRawPointer(guid).assumingMemoryBound(to: uuid_t.self).pointee
@ -152,15 +152,15 @@ extension SMB2Context {
}
var server: String? {
return context?.pointee.server.map(String.init(cString:))
return unsafe?.pointee.server.map(String.init(cString:))
}
var share: String? {
return context?.pointee.share.map(String.init(cString:))
return unsafe?.pointee.share.map(String.init(cString:))
}
var version: Version {
return (context?.pointee.dialect).map { Version(rawValue: UInt32($0)) } ?? .any
return (unsafe?.pointee.dialect).map { Version(rawValue: UInt32($0)) } ?? .any
}
var isConnected: Bool {
@ -174,27 +174,25 @@ extension SMB2Context {
}
var fileDescriptor: Int32 {
return (try? smb2_get_fd(context.unwrap())) ?? -1
return (try? smb2_get_fd(unsafe.unwrap())) ?? -1
}
var error: String? {
let errorStr = smb2_get_error(context)
let errorStr = smb2_get_error(unsafe)
return errorStr.map(String.init(cString:))
}
func whichEvents() throws -> Int16 {
return try Int16(truncatingIfNeeded: smb2_which_events(context.unwrap()))
return try Int16(truncatingIfNeeded: smb2_which_events(unsafe.unwrap()))
}
func service(revents: Int32) throws {
try withThreadSafeContext { (context) in
let result = smb2_service(context, revents)
if result < 0 {
self.context = nil
smb2_destroy_context(context)
}
try POSIXError.throwIfError(result, description: error)
let result = smb2_service(unsafe, revents)
if result < 0 {
self.unsafe = nil
smb2_destroy_context(unsafe)
}
try POSIXError.throwIfError(result, description: error)
}
}
@ -228,7 +226,7 @@ extension SMB2Context {
// MARK: DCE-RPC
extension SMB2Context {
func shareEnum() throws -> [SMB2Share] {
return try async_await(dataHandler: Parser.toSMB2Shares) { (context, cbPtr) -> Int32 in
return try async_await(dataHandler: [SMB2Share].init) { (context, cbPtr) -> Int32 in
smb2_share_enum_async(context, SMB2Context.generic_handler, cbPtr)
}.data
}
@ -236,7 +234,7 @@ extension SMB2Context {
func shareEnumSwift() throws -> [SMB2Share]
{
// Connection to server service.
let srvsvc = try SMB2FileHandle(path: "srvsvc", on: self)
let srvsvc = try SMB2FileHandle.using(path: "srvsvc", on: self)
// Bind command
_ = try srvsvc.write(data: MSRPC.srvsvcBindData())
let recvBindData = try srvsvc.pread(offset: 0, length: Int(Int16.max))
@ -268,7 +266,7 @@ extension SMB2Context {
}
func readlink(_ path: String) throws -> String {
return try async_await(dataHandler: Parser.toString) { (context, cbPtr) -> Int32 in
return try async_await(dataHandler: String.init) { (context, cbPtr) -> Int32 in
smb2_readlink_async(context, path, SMB2Context.generic_handler, cbPtr)
}.data
}
@ -352,16 +350,17 @@ extension SMB2Context {
} catch { }
}
typealias ContextHandler<R> = (_ context: UnsafeMutablePointer<smb2_context>, _ ptr: UnsafeMutableRawPointer?) throws -> R
typealias ContextHandler<R> = (_ context: SMB2Context, _ dataPtr: UnsafeMutableRawPointer?) throws -> R
typealias UnsafeContextHandler<R> = (_ context: UnsafeMutablePointer<smb2_context>, _ dataPtr: UnsafeMutableRawPointer?) throws -> R
@discardableResult
func async_await(execute handler: ContextHandler<Int32>) throws -> Int32
func async_await(execute handler: UnsafeContextHandler<Int32>) throws -> Int32
{
return try async_await(dataHandler: Parser.toVoid, execute: handler).result
return try async_await(dataHandler: { _, _ in }, execute: handler).result
}
@discardableResult
func async_await<DataType>(dataHandler: @escaping ContextHandler<DataType>, execute handler: ContextHandler<Int32>)
func async_await<DataType>(dataHandler: @escaping ContextHandler<DataType>, execute handler: UnsafeContextHandler<Int32>)
throws -> (result: Int32, data: DataType)
{
return try withThreadSafeContext { (context) -> (Int32, DataType) in
@ -370,7 +369,7 @@ extension SMB2Context {
var dataHandlerError: Error?
cb.dataHandler = { ptr in
do {
resultData = try dataHandler(context, ptr)
resultData = try dataHandler(self, ptr)
} catch {
dataHandlerError = error
}
@ -387,13 +386,13 @@ extension SMB2Context {
}
@discardableResult
func async_await_pdu(execute handler: ContextHandler<UnsafeMutablePointer<smb2_pdu>?>) throws -> UInt32
func async_await_pdu(execute handler: UnsafeContextHandler<UnsafeMutablePointer<smb2_pdu>?>) throws -> UInt32
{
return try async_await_pdu(dataHandler: Parser.toVoid, execute: handler).status
return try async_await_pdu(dataHandler: { _, _ in }, execute: handler).status
}
@discardableResult
func async_await_pdu<DataType>(dataHandler: @escaping ContextHandler<DataType>, execute handler: ContextHandler<UnsafeMutablePointer<smb2_pdu>?>)
func async_await_pdu<DataType>(dataHandler: @escaping ContextHandler<DataType>, execute handler: UnsafeContextHandler<UnsafeMutablePointer<smb2_pdu>?>)
throws -> (status: UInt32, data: DataType)
{
return try withThreadSafeContext { (context) -> (UInt32, DataType) in
@ -402,7 +401,7 @@ extension SMB2Context {
var dataHandlerError: Error?
cb.dataHandler = { ptr in
do {
resultData = try dataHandler(context, ptr)
resultData = try dataHandler(self, ptr)
} catch {
dataHandlerError = error
}

View File

@ -17,7 +17,7 @@ final class SMB2Directory: Collection {
private var handle: smb2dir
init(_ path: String, on context: SMB2Context) throws {
let (_, handle) = try context.async_await(dataHandler: Parser.toOpaquePointer) { (context, cbPtr) -> Int32 in
let (_, handle) = try context.async_await(dataHandler: OpaquePointer.init) { (context, cbPtr) -> Int32 in
smb2_opendir_async(context, path, SMB2Context.generic_handler, cbPtr)
}
@ -33,7 +33,7 @@ final class SMB2Directory: Collection {
}
func makeIterator() -> AnyIterator<smb2dirent> {
let context = self.context.context
let context = self.context.unsafe
let handle = self.handle
smb2_rewinddir(context, handle)
return AnyIterator {
@ -50,7 +50,7 @@ final class SMB2Directory: Collection {
}
var count: Int {
let context = self.context.context
let context = self.context.unsafe
let handle = self.handle
let currentPos = smb2_telldir(context, handle)
defer {
@ -66,7 +66,7 @@ final class SMB2Directory: Collection {
}
subscript(position: Int) -> smb2dirent {
let context = self.context.context
let context = self.context.unsafe
let handle = self.handle
let currentPos = smb2_telldir(context, handle)
smb2_seekdir(context, handle, 0)

View File

@ -28,9 +28,9 @@ extension Optional where Wrapped: SMB2Context {
}
extension POSIXError {
static func throwIfError(_ result: Int32, description: String?) throws {
static func throwIfError<Number: SignedInteger>(_ result: Number, description: String?) throws {
guard result < 0 else { return }
let errno = -result
let errno = Int32(-result)
let errorDesc = description.map { "Error code \(errno): \($0)" }
throw POSIXError(.init(errno), description: errorDesc)
}

View File

@ -45,16 +45,16 @@ final class SMB2FileHandle {
try self.init(path, flags: O_RDWR | O_APPEND, on: context)
}
convenience init(path: String,
opLock: Int32 = SMB2_OPLOCK_LEVEL_NONE,
impersonation: Int32 = SMB2_IMPERSONATION_IMPERSONATION,
desiredAccess: Int32 = SMB2_FILE_READ_DATA | SMB2_FILE_WRITE_DATA | SMB2_FILE_APPEND_DATA | SMB2_FILE_READ_EA |
static func using(path: String,
opLock: Int32 = SMB2_OPLOCK_LEVEL_NONE,
impersonation: Int32 = SMB2_IMPERSONATION_IMPERSONATION,
desiredAccess: Int32 = SMB2_FILE_READ_DATA | SMB2_FILE_WRITE_DATA | SMB2_FILE_APPEND_DATA | SMB2_FILE_READ_EA |
SMB2_FILE_READ_ATTRIBUTES | SMB2_FILE_WRITE_EA | SMB2_FILE_WRITE_ATTRIBUTES | SMB2_READ_CONTROL | SMB2_SYNCHRONIZE,
fileAttributes: Int32 = 0,
shareAccess: Int32 = SMB2_FILE_SHARE_READ | SMB2_FILE_SHARE_WRITE | SMB2_FILE_SHARE_DELETE,
createDisposition: Int32 = SMB2_FILE_OPEN,
createOptions: Int32 = 0, on context: SMB2Context) throws {
let (_, file_id) = try context.async_await_pdu(dataHandler: Parser.toFileId) { (context, cbPtr) -> UnsafeMutablePointer<smb2_pdu>? in
fileAttributes: Int32 = 0,
shareAccess: Int32 = SMB2_FILE_SHARE_READ | SMB2_FILE_SHARE_WRITE | SMB2_FILE_SHARE_DELETE,
createDisposition: Int32 = SMB2_FILE_OPEN,
createOptions: Int32 = 0, on context: SMB2Context) throws -> SMB2FileHandle {
let (_, result) = try context.async_await_pdu(dataHandler: SMB2FileHandle.init) { (context, cbPtr) -> UnsafeMutablePointer<smb2_pdu>? in
return path.replacingOccurrences(of: "/", with: "\\").withCString { (path) in
var req = smb2_create_request()
req.requested_oplock_level = UInt8(opLock)
@ -69,19 +69,17 @@ final class SMB2FileHandle {
}
}
try self.init(fileDescriptor: file_id, on: context)
return result
}
init(fileDescriptor: smb2_file_id, on context: SMB2Context) throws {
self.context = context
var fileDescriptor = fileDescriptor
self.handle = try context.withThreadSafeContext { context in
smb2_fh_from_file_id(context, &fileDescriptor)
}
self.handle = smb2_fh_from_file_id(context.unsafe, &fileDescriptor)
}
private init(_ path: String, flags: Int32, on context: SMB2Context) throws {
let (_, handle) = try context.async_await(dataHandler: Parser.toOpaquePointer) { (context, cbPtr) -> Int32 in
let (_, handle) = try context.async_await(dataHandler: OpaquePointer.init) { (context, cbPtr) -> Int32 in
smb2_open_async(context, path, flags, SMB2Context.generic_handler, cbPtr)
}
self.context = context
@ -137,10 +135,8 @@ final class SMB2FileHandle {
@discardableResult
func lseek(offset: Int64, whence: SeekWhence) throws -> Int64 {
let handle = try self.handle.unwrap()
let result = smb2_lseek(context.context, handle, offset, whence.rawValue, nil)
if result < 0 {
try POSIXError.throwIfError(Int32(result), description: context.error)
}
let result = smb2_lseek(context.unsafe, handle, offset, whence.rawValue, nil)
try POSIXError.throwIfError(Int32(result), description: context.error)
return result
}
@ -213,8 +209,7 @@ final class SMB2FileHandle {
return try inputBuffer.withUnsafeMutableBytes { (buf) in
var req = smb2_ioctl_request(ctl_code: command.rawValue, file_id: fileId, input_count: UInt32(buf.count),
input: buf.baseAddress, flags: UInt32(SMB2_0_IOCTL_IS_FSCTL))
let outputHandler = Parser.ioctlOutputConverter(as: R.self)
return try context.async_await_pdu(dataHandler: outputHandler) {
return try context.async_await_pdu(dataHandler: R.init) {
(context, cbPtr) -> UnsafeMutablePointer<smb2_pdu>? in
smb2_cmd_ioctl_async(context, &req, SMB2Context.generic_handler, cbPtr)
}.data

View File

@ -9,44 +9,19 @@
import Foundation
import SMB2
struct Parser {
static func toVoid(_ context: UnsafeMutablePointer<smb2_context>, _ dataPtr: UnsafeMutableRawPointer?) throws -> Void {
return
}
static func toString(_ context: UnsafeMutablePointer<smb2_context>, _ dataPtr: UnsafeMutableRawPointer?) throws -> String {
return try String(cString: dataPtr.unwrap().assumingMemoryBound(to: Int8.self))
}
static func toSMB2Shares(_ context: UnsafeMutablePointer<smb2_context>, _ dataPtr: UnsafeMutableRawPointer?) throws -> [SMB2Share] {
defer { smb2_free_data(context, dataPtr) }
let result = try dataPtr.unwrap().assumingMemoryBound(to: srvsvc_netshareenumall_rep.self).pointee
return .init(result.ctr.pointee.ctr1)
}
static func toOpaquePointer(_ context: UnsafeMutablePointer<smb2_context>, _ dataPtr: UnsafeMutableRawPointer?) throws -> OpaquePointer {
return try OpaquePointer(dataPtr.unwrap())
}
static func toFileId(_ context: UnsafeMutablePointer<smb2_context>, _ dataPtr: UnsafeMutableRawPointer?) throws -> smb2_file_id {
return try dataPtr.unwrap().assumingMemoryBound(to: smb2_create_reply.self).pointee.file_id
}
static func ioctlOutputConverter<R: DataInitializable>(as: R.Type) ->
((_ context: UnsafeMutablePointer<smb2_context>, _ dataPtr: UnsafeMutableRawPointer?) throws -> R) {
return { context, dataPtr in
let reply = try dataPtr.unwrap().assumingMemoryBound(to: smb2_ioctl_reply.self).pointee
guard reply.output_count > 0, let output = reply.output else {
return try .empty()
}
defer { smb2_free_data(context, output) }
let data = Data(bytes: output, count: Int(reply.output_count))
return try R(data: data)
}
extension String {
init(_ context: SMB2Context, _ dataPtr: UnsafeMutableRawPointer?) throws {
self = try String(cString: dataPtr.unwrap().assumingMemoryBound(to: Int8.self))
}
}
extension Array where Element == SMB2Share {
init(_ context: SMB2Context, _ dataPtr: UnsafeMutableRawPointer?) throws {
defer { smb2_free_data(context.unsafe, dataPtr) }
let result = try dataPtr.unwrap().assumingMemoryBound(to: srvsvc_netshareenumall_rep.self).pointee
self = Array(result.ctr.pointee.ctr1)
}
init(_ ctr1: srvsvc_netsharectr1) {
self = [srvsvc_netshareinfo1](UnsafeBufferPointer(start: ctr1.array, count: Int(ctr1.count))).map {
SMB2Share(name: .init(cString: $0.name),
@ -55,3 +30,29 @@ extension Array where Element == SMB2Share {
}
}
}
extension OpaquePointer {
init(_ context: SMB2Context, _ dataPtr: UnsafeMutableRawPointer?) throws {
self = try OpaquePointer(dataPtr.unwrap())
}
}
extension SMB2FileHandle {
convenience init(_ context: SMB2Context, _ dataPtr: UnsafeMutableRawPointer?) throws {
let fileId = try dataPtr.unwrap().assumingMemoryBound(to: smb2_create_reply.self).pointee.file_id
try self.init(fileDescriptor: fileId, on: context)
}
}
extension DataInitializable {
init(_ context: SMB2Context, _ dataPtr: UnsafeMutableRawPointer?) throws {
let reply = try dataPtr.unwrap().assumingMemoryBound(to: smb2_ioctl_reply.self).pointee
guard reply.output_count > 0, let output = reply.output else {
self = try Self.empty()
return
}
defer { smb2_free_data(context.unsafe, output) }
let data = Data(bytes: output, count: Int(reply.output_count))
self = try Self(data: data)
}
}

View File

@ -116,7 +116,7 @@ class AMSMB2Tests: XCTestCase {
func testShareEnum() {
let expectation = self.expectation(description: #function)
expectation.expectedFulfillmentCount = 3
expectation.expectedFulfillmentCount = 2
let smb = AMSMB2(url: server, credential: credential)!
smb.listShares { result in
@ -144,6 +144,7 @@ class AMSMB2Tests: XCTestCase {
}
}
expectation.expectedFulfillmentCount += 1
smb._swift_listShares { result in
switch result {
case .success(let value):